Merge branch 'main' into generic_configuration

# Conflicts:
#	src/field/extension_field/mod.rs
#	src/fri/recursive_verifier.rs
#	src/gadgets/arithmetic.rs
#	src/gadgets/arithmetic_extension.rs
#	src/gadgets/hash.rs
#	src/gadgets/interpolation.rs
#	src/gadgets/random_access.rs
#	src/gadgets/sorting.rs
#	src/gates/arithmetic_u32.rs
#	src/gates/gate_tree.rs
#	src/gates/interpolation.rs
#	src/gates/poseidon.rs
#	src/gates/poseidon_mds.rs
#	src/gates/random_access.rs
#	src/hash/hashing.rs
#	src/hash/merkle_proofs.rs
#	src/hash/poseidon.rs
#	src/iop/challenger.rs
#	src/iop/generator.rs
#	src/iop/witness.rs
#	src/plonk/circuit_data.rs
#	src/plonk/proof.rs
#	src/plonk/prover.rs
#	src/plonk/recursive_verifier.rs
#	src/util/partial_products.rs
#	src/util/reducing.rs
Author: wborgeaud
Date: 2021-12-16 14:54:38 +01:00
Commit: bdbc8b6931
105 changed files with 8303 additions and 2545 deletions

View File

@@ -28,9 +28,11 @@ jobs:
         with:
           command: test
           args: --all
+        env:
+          RUSTFLAGS: -Copt-level=3 -Cdebug-assertions -Coverflow-checks=y

   lints:
-    name: Formatting
+    name: Formatting and Clippy
     runs-on: ubuntu-latest
     if: "! contains(toJSON(github.event.commits.*.message), '[skip-ci]')"
     steps:
@@ -43,10 +45,17 @@ jobs:
           profile: minimal
           toolchain: nightly
           override: true
-          components: rustfmt
+          components: rustfmt, clippy
       - name: Run cargo fmt
         uses: actions-rs/cargo@v1
         with:
           command: fmt
           args: --all -- --check
+      - name: Run cargo clippy
+        uses: actions-rs/cargo@v1
+        with:
+          command: clippy
+          args: --all-features --all-targets -- -D warnings -A incomplete-features

View File

@@ -8,7 +8,7 @@ license = "MIT OR Apache-2.0"
 repository = "https://github.com/mir-protocol/plonky2"
 keywords = ["cryptography", "SNARK", "FRI"]
 categories = ["cryptography"]
-edition = "2018"
+edition = "2021"
 default-run = "bench_recursion"

 [dependencies]
@@ -28,6 +28,9 @@ serde_cbor = "0.11.1"
 keccak-hash = "0.8.0"
 static_assertions = "1.1.0"

+[target.'cfg(not(target_env = "msvc"))'.dependencies]
+jemallocator = "0.3.2"
+
 [dev-dependencies]
 criterion = "0.3.5"
 tynm = "0.1.6"

View File

@@ -1,7 +1,7 @@
 use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion};
 use plonky2::field::field_types::Field;
 use plonky2::field::goldilocks_field::GoldilocksField;
-use plonky2::polynomial::polynomial::PolynomialCoeffs;
+use plonky2::polynomial::PolynomialCoeffs;
 use tynm::type_name;

 pub(crate) fn bench_ffts<F: Field>(c: &mut Criterion) {

View File

@@ -1,5 +1,3 @@
-#![feature(destructuring_assignment)]
-
 use criterion::{criterion_group, criterion_main, BatchSize, Criterion};
 use plonky2::field::extension_field::quartic::QuarticExtension;
 use plonky2::field::field_types::Field;
@@ -112,6 +110,66 @@ pub(crate) fn bench_field<F: Field>(c: &mut Criterion) {
     c.bench_function(&format!("try_inverse<{}>", type_name::<F>()), |b| {
         b.iter_batched(|| F::rand(), |x| x.try_inverse(), BatchSize::SmallInput)
     });
+    c.bench_function(
+        &format!("batch_multiplicative_inverse-tiny<{}>", type_name::<F>()),
+        |b| {
+            b.iter_batched(
+                || (0..2).into_iter().map(|_| F::rand()).collect::<Vec<_>>(),
+                |x| F::batch_multiplicative_inverse(&x),
+                BatchSize::SmallInput,
+            )
+        },
+    );
+    c.bench_function(
+        &format!("batch_multiplicative_inverse-small<{}>", type_name::<F>()),
+        |b| {
+            b.iter_batched(
+                || (0..4).into_iter().map(|_| F::rand()).collect::<Vec<_>>(),
+                |x| F::batch_multiplicative_inverse(&x),
+                BatchSize::SmallInput,
+            )
+        },
+    );
+    c.bench_function(
+        &format!("batch_multiplicative_inverse-medium<{}>", type_name::<F>()),
+        |b| {
+            b.iter_batched(
+                || (0..16).into_iter().map(|_| F::rand()).collect::<Vec<_>>(),
+                |x| F::batch_multiplicative_inverse(&x),
+                BatchSize::SmallInput,
+            )
+        },
+    );
+    c.bench_function(
+        &format!("batch_multiplicative_inverse-large<{}>", type_name::<F>()),
+        |b| {
+            b.iter_batched(
+                || (0..256).into_iter().map(|_| F::rand()).collect::<Vec<_>>(),
+                |x| F::batch_multiplicative_inverse(&x),
+                BatchSize::LargeInput,
+            )
+        },
+    );
+    c.bench_function(
+        &format!("batch_multiplicative_inverse-huge<{}>", type_name::<F>()),
+        |b| {
+            b.iter_batched(
+                || {
+                    (0..65536)
+                        .into_iter()
+                        .map(|_| F::rand())
+                        .collect::<Vec<_>>()
+                },
+                |x| F::batch_multiplicative_inverse(&x),
+                BatchSize::LargeInput,
+            )
+        },
+    );
 }

 fn criterion_benchmark(c: &mut Criterion) {

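Aside: the five new benchmark sizes above all measure `batch_multiplicative_inverse`, i.e. Montgomery's trick (one field inversion plus O(n) multiplications). Below is a minimal standalone sketch of the serial variant, using a toy prime p = 65537 and u64 arithmetic rather than the crate's `Field` trait; all names are illustrative, and the commit's actual implementation (shown further down in src/field/field_types.rs) interleaves four product chains for instruction-level parallelism.

const P: u64 = 65537;

fn mul(a: u64, b: u64) -> u64 {
    a * b % P
}

// a^(p-2) mod p by square-and-multiply (Fermat inversion).
fn inv(a: u64) -> u64 {
    let (mut base, mut e, mut acc) = (a, P - 2, 1);
    while e > 0 {
        if e & 1 == 1 {
            acc = mul(acc, base);
        }
        base = mul(base, base);
        e >>= 1;
    }
    acc
}

// Invert every element using a single `inv` call and 3(n-1) multiplications.
fn batch_inverse(xs: &[u64]) -> Vec<u64> {
    let n = xs.len();
    if n == 0 {
        return Vec::new();
    }
    // Forward pass: prefix products prefix[i] = xs[0] * ... * xs[i].
    let mut prefix = Vec::with_capacity(n);
    let mut acc = 1;
    for &x in xs {
        acc = mul(acc, x);
        prefix.push(acc);
    }
    // Invert the total product once, then sweep backwards, peeling off one
    // factor at a time: prefix[i-1] * inv(prefix[i]) = inv(xs[i]).
    let mut inv_acc = inv(acc);
    let mut out = vec![0; n];
    for i in (1..n).rev() {
        out[i] = mul(prefix[i - 1], inv_acc);
        inv_acc = mul(inv_acc, xs[i]);
    }
    out[0] = inv_acc;
    out
}

fn main() {
    let xs = [3, 5, 7, 11];
    let invs = batch_inverse(&xs);
    for (&x, &xi) in xs.iter().zip(invs.iter()) {
        assert_eq!(mul(x, xi), 1);
    }
}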
View File

@@ -1,4 +1,3 @@
-#![feature(destructuring_assignment)]
 #![feature(generic_const_exprs)]

 use criterion::{criterion_group, criterion_main, BatchSize, Criterion};

View File

@@ -2,7 +2,7 @@ use std::time::Instant;

 use plonky2::field::field_types::Field;
 use plonky2::field::goldilocks_field::GoldilocksField;
-use plonky2::polynomial::polynomial::PolynomialValues;
+use plonky2::polynomial::PolynomialValues;
 use rayon::prelude::*;

 type F = GoldilocksField;

View File

@@ -24,6 +24,7 @@ fn bench_prove<C: GenericConfig<D>, const D: usize>() -> Result<()> {
         num_wires: 126,
         num_routed_wires: 33,
         constant_gate_size: 6,
+        use_base_arithmetic_gate: false,
         security_bits: 128,
         rate_bits: 3,
         num_challenges: 3,

View File

@@ -1,5 +1,7 @@
 //! Generates random constants using ChaCha20, seeded with zero.

+#![allow(clippy::needless_range_loop)]
+
 use plonky2::field::field_types::PrimeField;
 use plonky2::field::goldilocks_field::GoldilocksField;
 use rand::{Rng, SeedableRng};

src/curve/curve_adds.rs (new file, 156 lines)
View File

@@ -0,0 +1,156 @@
use std::ops::Add;
use crate::curve::curve_types::{AffinePoint, Curve, ProjectivePoint};
use crate::field::field_types::Field;
impl<C: Curve> Add<ProjectivePoint<C>> for ProjectivePoint<C> {
type Output = ProjectivePoint<C>;
fn add(self, rhs: ProjectivePoint<C>) -> Self::Output {
let ProjectivePoint {
x: x1,
y: y1,
z: z1,
} = self;
let ProjectivePoint {
x: x2,
y: y2,
z: z2,
} = rhs;
if z1 == C::BaseField::ZERO {
return rhs;
}
if z2 == C::BaseField::ZERO {
return self;
}
let x1z2 = x1 * z2;
let y1z2 = y1 * z2;
let x2z1 = x2 * z1;
let y2z1 = y2 * z1;
// Check if we're doubling or adding inverses.
if x1z2 == x2z1 {
if y1z2 == y2z1 {
// TODO: inline to avoid redundant muls.
return self.double();
}
if y1z2 == -y2z1 {
return ProjectivePoint::ZERO;
}
}
// From https://www.hyperelliptic.org/EFD/g1p/data/shortw/projective/addition/add-1998-cmo-2
let z1z2 = z1 * z2;
let u = y2z1 - y1z2;
let uu = u.square();
let v = x2z1 - x1z2;
let vv = v.square();
let vvv = v * vv;
let r = vv * x1z2;
let a = uu * z1z2 - vvv - r.double();
let x3 = v * a;
let y3 = u * (r - a) - vvv * y1z2;
let z3 = vvv * z1z2;
ProjectivePoint::nonzero(x3, y3, z3)
}
}
impl<C: Curve> Add<AffinePoint<C>> for ProjectivePoint<C> {
type Output = ProjectivePoint<C>;
fn add(self, rhs: AffinePoint<C>) -> Self::Output {
let ProjectivePoint {
x: x1,
y: y1,
z: z1,
} = self;
let AffinePoint {
x: x2,
y: y2,
zero: zero2,
} = rhs;
if z1 == C::BaseField::ZERO {
return rhs.to_projective();
}
if zero2 {
return self;
}
let x2z1 = x2 * z1;
let y2z1 = y2 * z1;
// Check if we're doubling or adding inverses.
if x1 == x2z1 {
if y1 == y2z1 {
// TODO: inline to avoid redundant muls.
return self.double();
}
if y1 == -y2z1 {
return ProjectivePoint::ZERO;
}
}
// From https://www.hyperelliptic.org/EFD/g1p/data/shortw/projective/addition/madd-1998-cmo
let u = y2z1 - y1;
let uu = u.square();
let v = x2z1 - x1;
let vv = v.square();
let vvv = v * vv;
let r = vv * x1;
let a = uu * z1 - vvv - r.double();
let x3 = v * a;
let y3 = u * (r - a) - vvv * y1;
let z3 = vvv * z1;
ProjectivePoint::nonzero(x3, y3, z3)
}
}
impl<C: Curve> Add<AffinePoint<C>> for AffinePoint<C> {
type Output = ProjectivePoint<C>;
fn add(self, rhs: AffinePoint<C>) -> Self::Output {
let AffinePoint {
x: x1,
y: y1,
zero: zero1,
} = self;
let AffinePoint {
x: x2,
y: y2,
zero: zero2,
} = rhs;
if zero1 {
return rhs.to_projective();
}
if zero2 {
return self.to_projective();
}
// Check if we're doubling or adding inverses.
if x1 == x2 {
if y1 == y2 {
return self.to_projective().double();
}
if y1 == -y2 {
return ProjectivePoint::ZERO;
}
}
// From https://www.hyperelliptic.org/EFD/g1p/data/shortw/projective/addition/mmadd-1998-cmo
let u = y2 - y1;
let uu = u.square();
let v = x2 - x1;
let vv = v.square();
let vvv = v * vv;
let r = vv * x1;
let a = uu - vvv - r.double();
let x3 = v * a;
let y3 = u * (r - a) - vvv * y1;
let z3 = vvv;
ProjectivePoint::nonzero(x3, y3, z3)
}
}

src/curve/curve_msm.rs (new file, 263 lines)
View File

@@ -0,0 +1,263 @@
use itertools::Itertools;
use rayon::prelude::*;
use crate::curve::curve_summation::affine_multisummation_best;
use crate::curve::curve_types::{AffinePoint, Curve, ProjectivePoint};
use crate::field::field_types::Field;
/// In Yao's method, we compute an affine summation for each digit. In a parallel setting, it would
/// be easiest to assign individual summations to threads, but this would be sub-optimal because
/// multi-summations can be more efficient than repeating individual summations (see
/// `affine_multisummation_best`). Thus we divide digits into large chunks, and assign chunks of
/// digits to threads. Note that there is a delicate balance here, as large chunks can result in
/// uneven distributions of work among threads.
const DIGITS_PER_CHUNK: usize = 80;
#[derive(Clone, Debug)]
pub struct MsmPrecomputation<C: Curve> {
/// For each generator g (in the order they were passed to `msm_precompute`), contains a vector
/// of powers, i.e. [(2^w)^i] g for each digit index i.
// TODO: Use compressed coordinates here.
powers_per_generator: Vec<Vec<AffinePoint<C>>>,
/// The window size.
w: usize,
}
pub fn msm_precompute<C: Curve>(
generators: &[ProjectivePoint<C>],
w: usize,
) -> MsmPrecomputation<C> {
MsmPrecomputation {
powers_per_generator: generators
.into_par_iter()
.map(|&g| precompute_single_generator(g, w))
.collect(),
w,
}
}
fn precompute_single_generator<C: Curve>(g: ProjectivePoint<C>, w: usize) -> Vec<AffinePoint<C>> {
let digits = (C::ScalarField::BITS + w - 1) / w;
let mut powers: Vec<ProjectivePoint<C>> = Vec::with_capacity(digits);
powers.push(g);
for i in 1..digits {
let mut power_i_proj = powers[i - 1];
for _j in 0..w {
power_i_proj = power_i_proj.double();
}
powers.push(power_i_proj);
}
ProjectivePoint::batch_to_affine(&powers)
}
pub fn msm_parallel<C: Curve>(
scalars: &[C::ScalarField],
generators: &[ProjectivePoint<C>],
w: usize,
) -> ProjectivePoint<C> {
let precomputation = msm_precompute(generators, w);
msm_execute_parallel(&precomputation, scalars)
}
pub fn msm_execute<C: Curve>(
precomputation: &MsmPrecomputation<C>,
scalars: &[C::ScalarField],
) -> ProjectivePoint<C> {
assert_eq!(precomputation.powers_per_generator.len(), scalars.len());
let w = precomputation.w;
let digits = (C::ScalarField::BITS + w - 1) / w;
let base = 1 << w;
// This is a variant of Yao's method, adapted to the multi-scalar setting. Because we use
// extremely large windows, the repeated scans in Yao's method could be more expensive than the
// actual group operations. To avoid this, we store a multimap from each possible digit to the
// positions in which that digit occurs in the scalars. These positions have the form (i, j),
// where i is the index of the generator and j is an index into the digits of the scalar
// associated with that generator.
let mut digit_occurrences: Vec<Vec<(usize, usize)>> = Vec::with_capacity(digits);
for _i in 0..base {
digit_occurrences.push(Vec::new());
}
for (i, scalar) in scalars.iter().enumerate() {
let digits = to_digits::<C>(scalar, w);
for (j, &digit) in digits.iter().enumerate() {
digit_occurrences[digit].push((i, j));
}
}
let mut y = ProjectivePoint::ZERO;
let mut u = ProjectivePoint::ZERO;
for digit in (1..base).rev() {
for &(i, j) in &digit_occurrences[digit] {
u = u + precomputation.powers_per_generator[i][j];
}
y = y + u;
}
y
}
pub fn msm_execute_parallel<C: Curve>(
precomputation: &MsmPrecomputation<C>,
scalars: &[C::ScalarField],
) -> ProjectivePoint<C> {
assert_eq!(precomputation.powers_per_generator.len(), scalars.len());
let w = precomputation.w;
let digits = (C::ScalarField::BITS + w - 1) / w;
let base = 1 << w;
// This is a variant of Yao's method, adapted to the multi-scalar setting. Because we use
// extremely large windows, the repeated scans in Yao's method could be more expensive than the
// actual group operations. To avoid this, we store a multimap from each possible digit to the
// positions in which that digit occurs in the scalars. These positions have the form (i, j),
// where i is the index of the generator and j is an index into the digits of the scalar
// associated with that generator.
let mut digit_occurrences: Vec<Vec<(usize, usize)>> = Vec::with_capacity(digits);
for _i in 0..base {
digit_occurrences.push(Vec::new());
}
for (i, scalar) in scalars.iter().enumerate() {
let digits = to_digits::<C>(scalar, w);
for (j, &digit) in digits.iter().enumerate() {
digit_occurrences[digit].push((i, j));
}
}
// For each digit, we add up the powers associated with all occurrences of that digit.
let digits: Vec<usize> = (0..base).collect();
let digit_acc: Vec<ProjectivePoint<C>> = digits
.par_chunks(DIGITS_PER_CHUNK)
.flat_map(|chunk| {
let summations: Vec<Vec<AffinePoint<C>>> = chunk
.iter()
.map(|&digit| {
digit_occurrences[digit]
.iter()
.map(|&(i, j)| precomputation.powers_per_generator[i][j])
.collect()
})
.collect();
affine_multisummation_best(summations)
})
.collect();
// println!("Computing the per-digit summations (in parallel) took {}s", start.elapsed().as_secs_f64());
let mut y = ProjectivePoint::ZERO;
let mut u = ProjectivePoint::ZERO;
for digit in (1..base).rev() {
u = u + digit_acc[digit];
y = y + u;
}
// println!("Final summation (sequential) {}s", start.elapsed().as_secs_f64());
y
}
pub(crate) fn to_digits<C: Curve>(x: &C::ScalarField, w: usize) -> Vec<usize> {
let scalar_bits = C::ScalarField::BITS;
let num_digits = (scalar_bits + w - 1) / w;
// Convert x to a bool array.
let x_canonical: Vec<_> = x
.to_biguint()
.to_u64_digits()
.iter()
.cloned()
.pad_using(scalar_bits / 64, |_| 0)
.collect();
let mut x_bits = Vec::with_capacity(scalar_bits);
for i in 0..scalar_bits {
x_bits.push((x_canonical[i / 64] >> (i as u64 % 64) & 1) != 0);
}
let mut digits = Vec::with_capacity(num_digits);
for i in 0..num_digits {
let mut digit = 0;
for j in ((i * w)..((i + 1) * w).min(scalar_bits)).rev() {
digit <<= 1;
digit |= x_bits[j] as usize;
}
digits.push(digit);
}
digits
}
#[cfg(test)]
mod tests {
use num::BigUint;
use crate::curve::curve_msm::{msm_execute, msm_precompute, to_digits};
use crate::curve::curve_types::Curve;
use crate::curve::secp256k1::Secp256K1;
use crate::field::field_types::Field;
use crate::field::secp256k1_scalar::Secp256K1Scalar;
#[test]
fn test_to_digits() {
let x_canonical = [
0b10101010101010101010101010101010,
0b10101010101010101010101010101010,
0b11001100110011001100110011001100,
0b11001100110011001100110011001100,
0b11110000111100001111000011110000,
0b11110000111100001111000011110000,
0b00001111111111111111111111111111,
0b11111111111111111111111111111111,
];
let x = Secp256K1Scalar::from_biguint(BigUint::from_slice(&x_canonical));
assert_eq!(x.to_biguint().to_u32_digits(), x_canonical);
assert_eq!(
to_digits::<Secp256K1>(&x, 17),
vec![
0b01010101010101010,
0b10101010101010101,
0b01010101010101010,
0b11001010101010101,
0b01100110011001100,
0b00110011001100110,
0b10011001100110011,
0b11110000110011001,
0b01111000011110000,
0b00111100001111000,
0b00011110000111100,
0b11111111111111110,
0b01111111111111111,
0b11111111111111000,
0b11111111111111111,
0b1,
]
);
}
#[test]
fn test_msm() {
let w = 5;
let generator_1 = Secp256K1::GENERATOR_PROJECTIVE;
let generator_2 = generator_1 + generator_1;
let generator_3 = generator_1 + generator_2;
let scalar_1 = Secp256K1Scalar::from_biguint(BigUint::from_slice(&[
11111111, 22222222, 33333333, 44444444,
]));
let scalar_2 = Secp256K1Scalar::from_biguint(BigUint::from_slice(&[
22222222, 22222222, 33333333, 44444444,
]));
let scalar_3 = Secp256K1Scalar::from_biguint(BigUint::from_slice(&[
33333333, 22222222, 33333333, 44444444,
]));
let generators = vec![generator_1, generator_2, generator_3];
let scalars = vec![scalar_1, scalar_2, scalar_3];
let precomputation = msm_precompute(&generators, w);
let result_msm = msm_execute(&precomputation, &scalars);
let result_naive = Secp256K1::convert(scalar_1) * generator_1
+ Secp256K1::convert(scalar_2) * generator_2
+ Secp256K1::convert(scalar_3) * generator_3;
assert_eq!(result_msm, result_naive);
}
}

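Aside: a standalone sketch of the base-2^w digit decomposition that `to_digits` performs above, using a u128 "scalar" in place of a 256-bit field element. Names are illustrative; the real function works limb-by-limb over the scalar's u64 digits.

// Split x into ceil(128 / w) little-endian digits in base 2^w.
fn to_digits(mut x: u128, w: usize) -> Vec<usize> {
    let mask = (1u128 << w) - 1;
    let num_digits = (128 + w - 1) / w;
    let mut digits = Vec::with_capacity(num_digits);
    for _ in 0..num_digits {
        digits.push((x & mask) as usize);
        x >>= w;
    }
    digits
}

fn main() {
    let x = 0xDEAD_BEEF_u128;
    let w = 5;
    let digits = to_digits(x, w);
    // Reconstruct x as sum_i digits[i] * 2^(w*i) to check the decomposition.
    let back = digits
        .iter()
        .rev()
        .fold(0u128, |acc, &d| (acc << w) | d as u128);
    assert_eq!(back, x);
}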
View File

@@ -0,0 +1,97 @@
use std::ops::Mul;
use crate::curve::curve_types::{Curve, CurveScalar, ProjectivePoint};
use crate::field::field_types::Field;
const WINDOW_BITS: usize = 4;
const BASE: usize = 1 << WINDOW_BITS;
fn digits_per_scalar<C: Curve>() -> usize {
(C::ScalarField::BITS + WINDOW_BITS - 1) / WINDOW_BITS
}
/// Precomputed state used for scalar x ProjectivePoint multiplications,
/// specific to a particular generator.
#[derive(Clone)]
pub struct MultiplicationPrecomputation<C: Curve> {
/// [(2^w)^i] g for each i < digits_per_scalar.
powers: Vec<ProjectivePoint<C>>,
}
impl<C: Curve> ProjectivePoint<C> {
pub fn mul_precompute(&self) -> MultiplicationPrecomputation<C> {
let num_digits = digits_per_scalar::<C>();
let mut powers = Vec::with_capacity(num_digits);
powers.push(*self);
for i in 1..num_digits {
let mut power_i = powers[i - 1];
for _j in 0..WINDOW_BITS {
power_i = power_i.double();
}
powers.push(power_i);
}
MultiplicationPrecomputation { powers }
}
pub fn mul_with_precomputation(
&self,
scalar: C::ScalarField,
precomputation: MultiplicationPrecomputation<C>,
) -> Self {
// Yao's method; see https://koclab.cs.ucsb.edu/teaching/ecc/eccPapers/Doche-ch09.pdf
let precomputed_powers = precomputation.powers;
let digits = to_digits::<C>(&scalar);
let mut y = ProjectivePoint::ZERO;
let mut u = ProjectivePoint::ZERO;
let mut all_summands = Vec::new();
for j in (1..BASE).rev() {
let mut u_summands = Vec::new();
for (i, &digit) in digits.iter().enumerate() {
if digit == j as u64 {
u_summands.push(precomputed_powers[i]);
}
}
all_summands.push(u_summands);
}
let all_sums: Vec<ProjectivePoint<C>> = all_summands
.iter()
.cloned()
.map(|vec| vec.iter().fold(ProjectivePoint::ZERO, |a, &b| a + b))
.collect();
for i in 0..all_sums.len() {
u = u + all_sums[i];
y = y + u;
}
y
}
}
impl<C: Curve> Mul<ProjectivePoint<C>> for CurveScalar<C> {
type Output = ProjectivePoint<C>;
fn mul(self, rhs: ProjectivePoint<C>) -> Self::Output {
let precomputation = rhs.mul_precompute();
rhs.mul_with_precomputation(self.0, precomputation)
}
}
#[allow(clippy::assertions_on_constants)]
fn to_digits<C: Curve>(x: &C::ScalarField) -> Vec<u64> {
debug_assert!(
64 % WINDOW_BITS == 0,
"For simplicity, only power-of-two window sizes are handled for now"
);
let digits_per_u64 = 64 / WINDOW_BITS;
let mut digits = Vec::with_capacity(digits_per_scalar::<C>());
for limb in x.to_biguint().to_u64_digits() {
for j in 0..digits_per_u64 {
digits.push((limb >> (j * WINDOW_BITS) as u64) % BASE as u64);
}
}
digits
}

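Aside: `msm_execute` (above) and `mul_with_precomputation` both finish with the same double-accumulator loop from Yao's method. A plain-integer sketch of why scanning digit values j = B-1 down to 1, while keeping u as the running sum of per-digit sums, makes y accumulate sum_j j * S_j without ever multiplying by j (names here are illustrative):

// per_digit_sums[j] plays the role of S_j, the sum of all points whose
// current digit equals j.
fn weighted_sum(per_digit_sums: &[i64]) -> i64 {
    let base = per_digit_sums.len();
    let (mut y, mut u) = (0, 0);
    for j in (1..base).rev() {
        u += per_digit_sums[j]; // u = S_j + S_{j+1} + ... + S_{B-1}
        y += u; // after the loop, y = sum_j j * S_j
    }
    y
}

fn main() {
    let s = [0, 10, 20, 30]; // S_0..S_3
    let expected: i64 = s.iter().enumerate().map(|(j, &v)| j as i64 * v).sum();
    assert_eq!(weighted_sum(&s), expected); // 1*10 + 2*20 + 3*30 = 140
}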
View File

@@ -0,0 +1,237 @@
use std::iter::Sum;
use crate::curve::curve_types::{AffinePoint, Curve, ProjectivePoint};
use crate::field::field_types::Field;
impl<C: Curve> Sum<AffinePoint<C>> for ProjectivePoint<C> {
fn sum<I: Iterator<Item = AffinePoint<C>>>(iter: I) -> ProjectivePoint<C> {
let points: Vec<_> = iter.collect();
affine_summation_best(points)
}
}
impl<C: Curve> Sum for ProjectivePoint<C> {
fn sum<I: Iterator<Item = ProjectivePoint<C>>>(iter: I) -> ProjectivePoint<C> {
iter.fold(ProjectivePoint::ZERO, |acc, x| acc + x)
}
}
pub fn affine_summation_best<C: Curve>(summation: Vec<AffinePoint<C>>) -> ProjectivePoint<C> {
let result = affine_multisummation_best(vec![summation]);
debug_assert_eq!(result.len(), 1);
result[0]
}
pub fn affine_multisummation_best<C: Curve>(
summations: Vec<Vec<AffinePoint<C>>>,
) -> Vec<ProjectivePoint<C>> {
let pairwise_sums: usize = summations.iter().map(|summation| summation.len() / 2).sum();
// This threshold is chosen based on data from the summation benchmarks.
if pairwise_sums < 70 {
affine_multisummation_pairwise(summations)
} else {
affine_multisummation_batch_inversion(summations)
}
}
/// Adds each pair of points using an affine + affine = projective formula, then adds up the
/// intermediate sums using a projective formula.
pub fn affine_multisummation_pairwise<C: Curve>(
summations: Vec<Vec<AffinePoint<C>>>,
) -> Vec<ProjectivePoint<C>> {
summations
.into_iter()
.map(affine_summation_pairwise)
.collect()
}
/// Adds each pair of points using an affine + affine = projective formula, then adds up the
/// intermediate sums using a projective formula.
pub fn affine_summation_pairwise<C: Curve>(points: Vec<AffinePoint<C>>) -> ProjectivePoint<C> {
let mut reduced_points: Vec<ProjectivePoint<C>> = Vec::new();
for chunk in points.chunks(2) {
match chunk.len() {
1 => reduced_points.push(chunk[0].to_projective()),
2 => reduced_points.push(chunk[0] + chunk[1]),
_ => panic!(),
}
}
// TODO: Avoid copying (deref)
reduced_points
.iter()
.fold(ProjectivePoint::ZERO, |sum, x| sum + *x)
}
/// Computes several summations of affine points by applying an affine group law, except that the
/// divisions are batched via Montgomery's trick.
pub fn affine_summation_batch_inversion<C: Curve>(
summation: Vec<AffinePoint<C>>,
) -> ProjectivePoint<C> {
let result = affine_multisummation_batch_inversion(vec![summation]);
debug_assert_eq!(result.len(), 1);
result[0]
}
/// Computes several summations of affine points by applying an affine group law, except that the
/// divisions are batched via Montgomery's trick.
pub fn affine_multisummation_batch_inversion<C: Curve>(
summations: Vec<Vec<AffinePoint<C>>>,
) -> Vec<ProjectivePoint<C>> {
let mut elements_to_invert = Vec::new();
// For each pair of points, (x1, y1) and (x2, y2), that we're going to add later, we want to
// invert either y (if the points are equal) or x1 - x2 (otherwise). We will use these later.
for summation in &summations {
let n = summation.len();
// The special case for n=0 is to avoid underflow.
let range_end = if n == 0 { 0 } else { n - 1 };
for i in (0..range_end).step_by(2) {
let p1 = summation[i];
let p2 = summation[i + 1];
let AffinePoint {
x: x1,
y: y1,
zero: zero1,
} = p1;
let AffinePoint {
x: x2,
y: _y2,
zero: zero2,
} = p2;
if zero1 || zero2 || p1 == -p2 {
// These are trivial cases where we won't need any inverse.
} else if p1 == p2 {
elements_to_invert.push(y1.double());
} else {
elements_to_invert.push(x1 - x2);
}
}
}
let inverses: Vec<C::BaseField> =
C::BaseField::batch_multiplicative_inverse(&elements_to_invert);
let mut all_reduced_points = Vec::with_capacity(summations.len());
let mut inverse_index = 0;
for summation in summations {
let n = summation.len();
let mut reduced_points = Vec::with_capacity((n + 1) / 2);
// The special case for n=0 is to avoid underflow.
let range_end = if n == 0 { 0 } else { n - 1 };
for i in (0..range_end).step_by(2) {
let p1 = summation[i];
let p2 = summation[i + 1];
let AffinePoint {
x: x1,
y: y1,
zero: zero1,
} = p1;
let AffinePoint {
x: x2,
y: y2,
zero: zero2,
} = p2;
let sum = if zero1 {
p2
} else if zero2 {
p1
} else if p1 == -p2 {
AffinePoint::ZERO
} else {
// It's a non-trivial case where we need one of the inverses we computed earlier.
let inverse = inverses[inverse_index];
inverse_index += 1;
if p1 == p2 {
// This is the doubling case.
let mut numerator = x1.square().triple();
if C::A.is_nonzero() {
numerator += C::A;
}
let quotient = numerator * inverse;
let x3 = quotient.square() - x1.double();
let y3 = quotient * (x1 - x3) - y1;
AffinePoint::nonzero(x3, y3)
} else {
// This is the general case. We use the incomplete addition formulas 4.3 and 4.4.
let quotient = (y1 - y2) * inverse;
let x3 = quotient.square() - x1 - x2;
let y3 = quotient * (x1 - x3) - y1;
AffinePoint::nonzero(x3, y3)
}
};
reduced_points.push(sum);
}
// If n is odd, the last point was not part of a pair.
if n % 2 == 1 {
reduced_points.push(summation[n - 1]);
}
all_reduced_points.push(reduced_points);
}
// We should have consumed all of the inverses from the batch computation.
debug_assert_eq!(inverse_index, inverses.len());
// Recurse with our smaller set of points.
affine_multisummation_best(all_reduced_points)
}
#[cfg(test)]
mod tests {
use crate::curve::curve_summation::{
affine_summation_batch_inversion, affine_summation_pairwise,
};
use crate::curve::curve_types::{Curve, ProjectivePoint};
use crate::curve::secp256k1::Secp256K1;
#[test]
fn test_pairwise_affine_summation() {
let g_affine = Secp256K1::GENERATOR_AFFINE;
let g2_affine = (g_affine + g_affine).to_affine();
let g3_affine = (g_affine + g_affine + g_affine).to_affine();
let g2_proj = g2_affine.to_projective();
let g3_proj = g3_affine.to_projective();
assert_eq!(
affine_summation_pairwise::<Secp256K1>(vec![g_affine, g_affine]),
g2_proj
);
assert_eq!(
affine_summation_pairwise::<Secp256K1>(vec![g_affine, g2_affine]),
g3_proj
);
assert_eq!(
affine_summation_pairwise::<Secp256K1>(vec![g_affine, g_affine, g_affine]),
g3_proj
);
assert_eq!(
affine_summation_pairwise::<Secp256K1>(vec![]),
ProjectivePoint::ZERO
);
}
#[test]
fn test_pairwise_affine_summation_batch_inversion() {
let g = Secp256K1::GENERATOR_AFFINE;
let g_proj = g.to_projective();
assert_eq!(
affine_summation_batch_inversion::<Secp256K1>(vec![g, g]),
g_proj + g_proj
);
assert_eq!(
affine_summation_batch_inversion::<Secp256K1>(vec![g, g, g]),
g_proj + g_proj + g_proj
);
assert_eq!(
affine_summation_batch_inversion::<Secp256K1>(vec![]),
ProjectivePoint::ZERO
);
}
}

src/curve/curve_types.rs (new file, 260 lines)
View File

@@ -0,0 +1,260 @@
use std::fmt::Debug;
use std::ops::Neg;
use crate::field::field_types::Field;
// To avoid implementation conflicts from associated types,
// see https://github.com/rust-lang/rust/issues/20400
pub struct CurveScalar<C: Curve>(pub <C as Curve>::ScalarField);
/// A short Weierstrass curve.
pub trait Curve: 'static + Sync + Sized + Copy + Debug {
type BaseField: Field;
type ScalarField: Field;
const A: Self::BaseField;
const B: Self::BaseField;
const GENERATOR_AFFINE: AffinePoint<Self>;
const GENERATOR_PROJECTIVE: ProjectivePoint<Self> = ProjectivePoint {
x: Self::GENERATOR_AFFINE.x,
y: Self::GENERATOR_AFFINE.y,
z: Self::BaseField::ONE,
};
fn convert(x: Self::ScalarField) -> CurveScalar<Self> {
CurveScalar(x)
}
fn is_safe_curve() -> bool {
// Additional check that the discriminant is nonzero; a zero discriminant gives a singular
// curve, which would be vulnerable.
(Self::A.cube().double().double() + Self::B.square().triple().triple().triple())
.is_nonzero()
}
}
/// A point on a short Weierstrass curve, represented in affine coordinates.
#[derive(Copy, Clone, Debug)]
pub struct AffinePoint<C: Curve> {
pub x: C::BaseField,
pub y: C::BaseField,
pub zero: bool,
}
impl<C: Curve> AffinePoint<C> {
pub const ZERO: Self = Self {
x: C::BaseField::ZERO,
y: C::BaseField::ZERO,
zero: true,
};
pub fn nonzero(x: C::BaseField, y: C::BaseField) -> Self {
let point = Self { x, y, zero: false };
debug_assert!(point.is_valid());
point
}
pub fn is_valid(&self) -> bool {
let Self { x, y, zero } = *self;
zero || y.square() == x.cube() + C::A * x + C::B
}
pub fn to_projective(&self) -> ProjectivePoint<C> {
let Self { x, y, zero } = *self;
let z = if zero {
C::BaseField::ZERO
} else {
C::BaseField::ONE
};
ProjectivePoint { x, y, z }
}
pub fn batch_to_projective(affine_points: &[Self]) -> Vec<ProjectivePoint<C>> {
affine_points.iter().map(Self::to_projective).collect()
}
pub fn double(&self) -> Self {
let AffinePoint { x: x1, y: y1, zero } = *self;
if zero {
return AffinePoint::ZERO;
}
let double_y = y1.double();
let inv_double_y = double_y.inverse(); // (2y)^(-1)
let triple_xx = x1.square().triple(); // 3x^2
let lambda = (triple_xx + C::A) * inv_double_y;
let x3 = lambda.square() - self.x.double();
let y3 = lambda * (x1 - x3) - y1;
Self {
x: x3,
y: y3,
zero: false,
}
}
}
impl<C: Curve> PartialEq for AffinePoint<C> {
fn eq(&self, other: &Self) -> bool {
let AffinePoint {
x: x1,
y: y1,
zero: zero1,
} = *self;
let AffinePoint {
x: x2,
y: y2,
zero: zero2,
} = *other;
if zero1 || zero2 {
return zero1 == zero2;
}
x1 == x2 && y1 == y2
}
}
impl<C: Curve> Eq for AffinePoint<C> {}
/// A point on a short Weierstrass curve, represented in projective coordinates.
#[derive(Copy, Clone, Debug)]
pub struct ProjectivePoint<C: Curve> {
pub x: C::BaseField,
pub y: C::BaseField,
pub z: C::BaseField,
}
impl<C: Curve> ProjectivePoint<C> {
pub const ZERO: Self = Self {
x: C::BaseField::ZERO,
y: C::BaseField::ONE,
z: C::BaseField::ZERO,
};
pub fn nonzero(x: C::BaseField, y: C::BaseField, z: C::BaseField) -> Self {
let point = Self { x, y, z };
debug_assert!(point.is_valid());
point
}
pub fn is_valid(&self) -> bool {
let Self { x, y, z } = *self;
z.is_zero() || y.square() * z == x.cube() + C::A * x * z.square() + C::B * z.cube()
}
pub fn to_affine(&self) -> AffinePoint<C> {
let Self { x, y, z } = *self;
if z == C::BaseField::ZERO {
AffinePoint::ZERO
} else {
let z_inv = z.inverse();
AffinePoint::nonzero(x * z_inv, y * z_inv)
}
}
pub fn batch_to_affine(proj_points: &[Self]) -> Vec<AffinePoint<C>> {
let n = proj_points.len();
let zs: Vec<C::BaseField> = proj_points.iter().map(|pp| pp.z).collect();
let z_invs = C::BaseField::batch_multiplicative_inverse(&zs);
let mut result = Vec::with_capacity(n);
for i in 0..n {
let Self { x, y, z } = proj_points[i];
result.push(if z == C::BaseField::ZERO {
AffinePoint::ZERO
} else {
let z_inv = z_invs[i];
AffinePoint::nonzero(x * z_inv, y * z_inv)
});
}
result
}
// From https://www.hyperelliptic.org/EFD/g1p/data/shortw/projective/doubling/dbl-2007-bl
pub fn double(&self) -> Self {
let Self { x, y, z } = *self;
if z == C::BaseField::ZERO {
return ProjectivePoint::ZERO;
}
let xx = x.square();
let zz = z.square();
let mut w = xx.triple();
if C::A.is_nonzero() {
w += C::A * zz;
}
let s = y.double() * z;
let r = y * s;
let rr = r.square();
let b = (x + r).square() - (xx + rr);
let h = w.square() - b.double();
let x3 = h * s;
let y3 = w * (b - h) - rr.double();
let z3 = s.cube();
Self {
x: x3,
y: y3,
z: z3,
}
}
pub fn add_slices(a: &[Self], b: &[Self]) -> Vec<Self> {
assert_eq!(a.len(), b.len());
a.iter()
.zip(b.iter())
.map(|(&a_i, &b_i)| a_i + b_i)
.collect()
}
pub fn neg(&self) -> Self {
Self {
x: self.x,
y: -self.y,
z: self.z,
}
}
}
impl<C: Curve> PartialEq for ProjectivePoint<C> {
fn eq(&self, other: &Self) -> bool {
let ProjectivePoint {
x: x1,
y: y1,
z: z1,
} = *self;
let ProjectivePoint {
x: x2,
y: y2,
z: z2,
} = *other;
if z1 == C::BaseField::ZERO || z2 == C::BaseField::ZERO {
return z1 == z2;
}
// We want to compare (x1/z1, y1/z1) == (x2/z2, y2/z2).
// But to avoid field division, it is better to compare (x1*z2, y1*z2) == (x2*z1, y2*z1).
x1 * z2 == x2 * z1 && y1 * z2 == y2 * z1
}
}
impl<C: Curve> Eq for ProjectivePoint<C> {}
impl<C: Curve> Neg for AffinePoint<C> {
type Output = AffinePoint<C>;
fn neg(self) -> Self::Output {
let AffinePoint { x, y, zero } = self;
AffinePoint { x, y: -y, zero }
}
}
impl<C: Curve> Neg for ProjectivePoint<C> {
type Output = ProjectivePoint<C>;
fn neg(self) -> Self::Output {
let ProjectivePoint { x, y, z } = self;
ProjectivePoint { x, y: -y, z }
}
}

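Aside: the division-free equality check in `PartialEq for ProjectivePoint` above is ordinary cross-multiplication. The same trick for plain rationals, as a minimal sketch:

// a/b == c/d iff a*d == c*b (for nonzero denominators), with no division.
fn rat_eq((a, b): (i64, i64), (c, d): (i64, i64)) -> bool {
    assert!(b != 0 && d != 0);
    a * d == c * b
}

fn main() {
    assert!(rat_eq((2, 4), (3, 6))); // both equal 1/2
    assert!(!rat_eq((2, 4), (2, 3)));
}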
src/curve/mod.rs (new file, 6 lines)
View File

@@ -0,0 +1,6 @@
pub mod curve_adds;
pub mod curve_msm;
pub mod curve_multiplication;
pub mod curve_summation;
pub mod curve_types;
pub mod secp256k1;

src/curve/secp256k1.rs (new file, 98 lines)
View File

@@ -0,0 +1,98 @@
use crate::curve::curve_types::{AffinePoint, Curve};
use crate::field::field_types::Field;
use crate::field::secp256k1_base::Secp256K1Base;
use crate::field::secp256k1_scalar::Secp256K1Scalar;
#[derive(Debug, Copy, Clone)]
pub struct Secp256K1;
impl Curve for Secp256K1 {
type BaseField = Secp256K1Base;
type ScalarField = Secp256K1Scalar;
const A: Secp256K1Base = Secp256K1Base::ZERO;
const B: Secp256K1Base = Secp256K1Base([7, 0, 0, 0]);
const GENERATOR_AFFINE: AffinePoint<Self> = AffinePoint {
x: SECP256K1_GENERATOR_X,
y: SECP256K1_GENERATOR_Y,
zero: false,
};
}
// 55066263022277343669578718895168534326250603453777594175500187360389116729240
const SECP256K1_GENERATOR_X: Secp256K1Base = Secp256K1Base([
0x59F2815B16F81798,
0x029BFCDB2DCE28D9,
0x55A06295CE870B07,
0x79BE667EF9DCBBAC,
]);
/// 32670510020758816978083085130507043184471273380659243275938904335757337482424
const SECP256K1_GENERATOR_Y: Secp256K1Base = Secp256K1Base([
0x9C47D08FFB10D4B8,
0xFD17B448A6855419,
0x5DA4FBFC0E1108A8,
0x483ADA7726A3C465,
]);
#[cfg(test)]
mod tests {
use num::BigUint;
use crate::curve::curve_types::{AffinePoint, Curve, ProjectivePoint};
use crate::curve::secp256k1::Secp256K1;
use crate::field::field_types::Field;
use crate::field::secp256k1_scalar::Secp256K1Scalar;
#[test]
fn test_generator() {
let g = Secp256K1::GENERATOR_AFFINE;
assert!(g.is_valid());
let neg_g = AffinePoint::<Secp256K1> {
x: g.x,
y: -g.y,
zero: g.zero,
};
assert!(neg_g.is_valid());
}
#[test]
fn test_naive_multiplication() {
let g = Secp256K1::GENERATOR_PROJECTIVE;
let ten = Secp256K1Scalar::from_canonical_u64(10);
let product = mul_naive(ten, g);
let sum = g + g + g + g + g + g + g + g + g + g;
assert_eq!(product, sum);
}
#[test]
fn test_g1_multiplication() {
let lhs = Secp256K1Scalar::from_biguint(BigUint::from_slice(&[
1111, 2222, 3333, 4444, 5555, 6666, 7777, 8888,
]));
assert_eq!(
Secp256K1::convert(lhs) * Secp256K1::GENERATOR_PROJECTIVE,
mul_naive(lhs, Secp256K1::GENERATOR_PROJECTIVE)
);
}
/// A simple, somewhat inefficient implementation of multiplication which is used as a reference
/// for correctness.
fn mul_naive(
lhs: Secp256K1Scalar,
rhs: ProjectivePoint<Secp256K1>,
) -> ProjectivePoint<Secp256K1> {
let mut g = rhs;
let mut sum = ProjectivePoint::ZERO;
for limb in lhs.to_biguint().to_u64_digits().iter() {
for j in 0..64 {
if (limb >> j & 1u64) != 0u64 {
sum = sum + g;
}
g = g.double();
}
}
sum
}
}

View File

@@ -31,8 +31,6 @@ mod tests {
     #[test]
     fn distinct_cosets() {
-        // TODO: Switch to a smaller test field so that collision rejection is likely to occur.
         type F = GoldilocksField;
         const SUBGROUP_BITS: usize = 5;
         const NUM_SHIFTS: usize = 50;

View File

@@ -160,12 +160,32 @@ impl<F: OEF<D>, const D: usize> PolynomialCoeffsAlgebra<F, D> {
             .fold(ExtensionAlgebra::ZERO, |acc, &c| acc * x + c)
     }

+    /// Evaluate the polynomial at a point given its powers. The first power is the point itself, not 1.
+    pub fn eval_with_powers(&self, powers: &[ExtensionAlgebra<F, D>]) -> ExtensionAlgebra<F, D> {
+        debug_assert_eq!(self.coeffs.len(), powers.len() + 1);
+        let acc = self.coeffs[0];
+        self.coeffs[1..]
+            .iter()
+            .zip(powers)
+            .fold(acc, |acc, (&x, &c)| acc + c * x)
+    }
+
     pub fn eval_base(&self, x: F) -> ExtensionAlgebra<F, D> {
         self.coeffs
             .iter()
             .rev()
             .fold(ExtensionAlgebra::ZERO, |acc, &c| acc.scalar_mul(x) + c)
     }
+
+    /// Evaluate the polynomial at a point given its powers. The first power is the point itself, not 1.
+    pub fn eval_base_with_powers(&self, powers: &[F]) -> ExtensionAlgebra<F, D> {
+        debug_assert_eq!(self.coeffs.len(), powers.len() + 1);
+        let acc = self.coeffs[0];
+        self.coeffs[1..]
+            .iter()
+            .zip(powers)
+            .fold(acc, |acc, (&x, &c)| acc + x.scalar_mul(c))
+    }
 }

 #[cfg(test)]

View File

@@ -1,3 +1,4 @@
+use crate::field::field_types::{Field, PrimeField};
 use std::convert::TryInto;

 use crate::field::field_types::{Field, RichField};

View File

@@ -3,6 +3,7 @@ use std::iter::{Product, Sum};
 use std::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssign};

 use num::bigint::BigUint;
+use num::Integer;
 use rand::Rng;
 use serde::{Deserialize, Serialize};
@@ -49,26 +50,28 @@ impl<F: Extendable<2>> From<F> for QuadraticExtension<F> {
 }

 impl<F: Extendable<2>> Field for QuadraticExtension<F> {
+    type PrimeField = F;
+
     const ZERO: Self = Self([F::ZERO; 2]);
     const ONE: Self = Self([F::ONE, F::ZERO]);
     const TWO: Self = Self([F::TWO, F::ZERO]);
     const NEG_ONE: Self = Self([F::NEG_ONE, F::ZERO]);

-    const CHARACTERISTIC: u64 = F::CHARACTERISTIC;
-
     // `p^2 - 1 = (p - 1)(p + 1)`. The `p - 1` term has a two-adicity of `F::TWO_ADICITY`. As
     // long as `F::TWO_ADICITY >= 2`, `p` can be written as `4n + 1`, so `p + 1` can be written as
     // `2(2n + 1)`, which has a 2-adicity of 1.
     const TWO_ADICITY: usize = F::TWO_ADICITY + 1;
+    const CHARACTERISTIC_TWO_ADICITY: usize = F::CHARACTERISTIC_TWO_ADICITY;

     const MULTIPLICATIVE_GROUP_GENERATOR: Self = Self(F::EXT_MULTIPLICATIVE_GROUP_GENERATOR);
     const POWER_OF_TWO_GENERATOR: Self = Self(F::EXT_POWER_OF_TWO_GENERATOR);
+    const BITS: usize = F::BITS * 2;

     fn order() -> BigUint {
         F::order() * F::order()
     }
+    fn characteristic() -> BigUint {
+        F::characteristic()
+    }

     #[inline(always)]
     fn square(&self) -> Self {
@@ -99,6 +102,15 @@ impl<F: Extendable<2>> Field for QuadraticExtension<F> {
         ))
     }

+    fn from_biguint(n: BigUint) -> Self {
+        let (high, low) = n.div_rem(&F::order());
+        Self([F::from_biguint(low), F::from_biguint(high)])
+    }
+
+    fn to_biguint(&self) -> BigUint {
+        self.0[0].to_biguint() + F::order() * self.0[1].to_biguint()
+    }
+
     fn from_canonical_u64(n: u64) -> Self {
         F::from_canonical_u64(n).into()
     }

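Aside: `from_biguint`/`to_biguint` above treat a quadratic-extension element (a0, a1) as a two-digit base-|F| number, a0 + |F| * a1. The same round trip with u128 in place of BigUint and an illustrative modulus (p = 2^61 - 1), as a minimal sketch:

const P: u128 = (1 << 61) - 1;

// (a0, a1) -> a0 + p * a1, the base-p encoding of the pair.
fn to_int(a: [u128; 2]) -> u128 {
    a[0] + P * a[1]
}

// n -> (low, high), i.e. n.div_rem(&p) with the remainder as the low digit.
fn from_int(n: u128) -> [u128; 2] {
    [n % P, n / P]
}

fn main() {
    let a = [12345, 67890];
    assert_eq!(from_int(to_int(a)), a);
}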
View File

@@ -4,6 +4,7 @@ use std::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssign};

 use num::bigint::BigUint;
 use num::traits::Pow;
+use num::Integer;
 use rand::Rng;
 use serde::{Deserialize, Serialize};
@@ -50,27 +51,29 @@ impl<F: Extendable<4>> From<F> for QuarticExtension<F> {
 }

 impl<F: Extendable<4>> Field for QuarticExtension<F> {
+    type PrimeField = F;
+
     const ZERO: Self = Self([F::ZERO; 4]);
     const ONE: Self = Self([F::ONE, F::ZERO, F::ZERO, F::ZERO]);
     const TWO: Self = Self([F::TWO, F::ZERO, F::ZERO, F::ZERO]);
     const NEG_ONE: Self = Self([F::NEG_ONE, F::ZERO, F::ZERO, F::ZERO]);

-    const CHARACTERISTIC: u64 = F::ORDER;
-
     // `p^4 - 1 = (p - 1)(p + 1)(p^2 + 1)`. The `p - 1` term has a two-adicity of `F::TWO_ADICITY`.
     // As long as `F::TWO_ADICITY >= 2`, `p` can be written as `4n + 1`, so `p + 1` can be written as
     // `2(2n + 1)`, which has a 2-adicity of 1. A similar argument can show that `p^2 + 1` also has
     // a 2-adicity of 1.
     const TWO_ADICITY: usize = F::TWO_ADICITY + 2;
+    const CHARACTERISTIC_TWO_ADICITY: usize = F::CHARACTERISTIC_TWO_ADICITY;

     const MULTIPLICATIVE_GROUP_GENERATOR: Self = Self(F::EXT_MULTIPLICATIVE_GROUP_GENERATOR);
     const POWER_OF_TWO_GENERATOR: Self = Self(F::EXT_POWER_OF_TWO_GENERATOR);
+    const BITS: usize = F::BITS * 4;

     fn order() -> BigUint {
         F::order().pow(4u32)
     }
+    fn characteristic() -> BigUint {
+        F::characteristic()
+    }

     #[inline(always)]
     fn square(&self) -> Self {
@@ -104,6 +107,26 @@ impl<F: Extendable<4>> Field for QuarticExtension<F> {
         ))
     }

+    fn from_biguint(n: BigUint) -> Self {
+        let (rest, first) = n.div_rem(&F::order());
+        let (rest, second) = rest.div_rem(&F::order());
+        let (rest, third) = rest.div_rem(&F::order());
+        Self([
+            F::from_biguint(first),
+            F::from_biguint(second),
+            F::from_biguint(third),
+            F::from_biguint(rest),
+        ])
+    }
+
+    fn to_biguint(&self) -> BigUint {
+        let mut result = self.0[3].to_biguint();
+        result = result * F::order() + self.0[2].to_biguint();
+        result = result * F::order() + self.0[1].to_biguint();
+        result = result * F::order() + self.0[0].to_biguint();
+        result
+    }
+
     fn from_canonical_u64(n: u64) -> Self {
         F::from_canonical_u64(n).into()
     }

View File

@@ -1,4 +1,3 @@
-use std::convert::{TryFrom, TryInto};
 use std::ops::Range;

 use crate::field::extension_field::algebra::ExtensionAlgebra;
@@ -33,6 +32,7 @@ impl<const D: usize> ExtensionTarget<D> {
         let arr = self.to_target_array();
         let k = (F::order() - 1u32) / (D as u64);
         let z0 = F::Extension::W.exp_biguint(&(k * count as u64));
+        #[allow(clippy::needless_collect)]
         let zs = z0
             .powers()
             .take(D)

View File

@@ -5,8 +5,8 @@ use unroll::unroll_for_loops;

 use crate::field::field_types::Field;
 use crate::field::packable::Packable;
-use crate::field::packed_field::{PackedField, Singleton};
-use crate::polynomial::polynomial::{PolynomialCoeffs, PolynomialValues};
+use crate::field::packed_field::PackedField;
+use crate::polynomial::{PolynomialCoeffs, PolynomialValues};
 use crate::util::{log2_strict, reverse_index_bits};

 pub(crate) type FftRootTable<F> = Vec<Vec<F>>;
@@ -38,7 +38,7 @@ fn fft_dispatch<F: Field>(
     zero_factor: Option<usize>,
     root_table: Option<&FftRootTable<F>>,
 ) -> Vec<F> {
-    let computed_root_table = if let Some(_) = root_table {
+    let computed_root_table = if root_table.is_some() {
         None
     } else {
         Some(fft_root_table(input.len()))
@@ -98,12 +98,12 @@ pub fn ifft_with_options<F: Field>(
 /// Generic FFT implementation that works with both scalar and packed inputs.
 #[unroll_for_loops]
 fn fft_classic_simd<P: PackedField>(
-    values: &mut [P::FieldType],
+    values: &mut [P::Scalar],
     r: usize,
     lg_n: usize,
-    root_table: &FftRootTable<P::FieldType>,
+    root_table: &FftRootTable<P::Scalar>,
 ) {
-    let lg_packed_width = P::LOG2_WIDTH; // 0 when P is a scalar.
+    let lg_packed_width = log2_strict(P::WIDTH); // 0 when P is a scalar.
     let packed_values = P::pack_slice_mut(values);
     let packed_n = packed_values.len();
     debug_assert!(packed_n == 1 << (lg_n - lg_packed_width));
@@ -121,19 +121,18 @@ fn fft_classic_simd<P: PackedField>(
             let half_m = 1 << lg_half_m;

             // Set omega to root_table[lg_half_m][0..half_m] but repeated.
-            let mut omega_vec = P::zero().to_vec();
-            for j in 0..omega_vec.len() {
-                omega_vec[j] = root_table[lg_half_m][j % half_m];
+            let mut omega = P::ZERO;
+            for (j, omega_j) in omega.as_slice_mut().iter_mut().enumerate() {
+                *omega_j = root_table[lg_half_m][j % half_m];
             }
-            let omega = P::from_slice(&omega_vec[..]);

             for k in (0..packed_n).step_by(2) {
                 // We have two vectors and want to do math on pairs of adjacent elements (or for
                 // lg_half_m > 0, pairs of adjacent blocks of elements). .interleave does the
                 // appropriate shuffling and is its own inverse.
-                let (u, v) = packed_values[k].interleave(packed_values[k + 1], lg_half_m);
+                let (u, v) = packed_values[k].interleave(packed_values[k + 1], half_m);
                 let t = omega * v;
-                (packed_values[k], packed_values[k + 1]) = (u + t).interleave(u - t, lg_half_m);
+                (packed_values[k], packed_values[k + 1]) = (u + t).interleave(u - t, half_m);
             }
         }
     }
@@ -197,13 +196,13 @@ pub(crate) fn fft_classic<F: Field>(input: &[F], r: usize, root_table: &FftRootT
         }
     }

-    let lg_packed_width = <F as Packable>::PackedType::LOG2_WIDTH;
+    let lg_packed_width = log2_strict(<F as Packable>::Packing::WIDTH);
     if lg_n <= lg_packed_width {
         // Need the slice to be at least the width of two packed vectors for the vectorized version
         // to work. Do this tiny problem in scalar.
-        fft_classic_simd::<Singleton<F>>(&mut values[..], r, lg_n, &root_table);
+        fft_classic_simd::<F>(&mut values[..], r, lg_n, root_table);
     } else {
-        fft_classic_simd::<<F as Packable>::PackedType>(&mut values[..], r, lg_n, &root_table);
+        fft_classic_simd::<<F as Packable>::Packing>(&mut values[..], r, lg_n, root_table);
     }
     values
 }
@@ -213,19 +212,23 @@ mod tests {
     use crate::field::fft::{fft, fft_with_options, ifft};
     use crate::field::field_types::Field;
     use crate::field::goldilocks_field::GoldilocksField;
-    use crate::polynomial::polynomial::{PolynomialCoeffs, PolynomialValues};
+    use crate::polynomial::{PolynomialCoeffs, PolynomialValues};
     use crate::util::{log2_ceil, log2_strict};

     #[test]
     fn fft_and_ifft() {
         type F = GoldilocksField;
-        let degree = 200;
-        let degree_padded = log2_ceil(degree);
-        let mut coefficients = Vec::new();
-        for i in 0..degree {
-            coefficients.push(F::from_canonical_usize(i * 1337 % 100));
-        }
-        let coefficients = PolynomialCoeffs::new_padded(coefficients);
+        let degree = 200usize;
+        let degree_padded = degree.next_power_of_two();
+
+        // Create a vector of coeffs; the first degree of them are
+        // "random", the last degree_padded-degree of them are zero.
+        let coeffs = (0..degree)
+            .map(|i| F::from_canonical_usize(i * 1337 % 100))
+            .chain(std::iter::repeat(F::ZERO).take(degree_padded - degree))
+            .collect::<Vec<_>>();
+        assert_eq!(coeffs.len(), degree_padded);
+        let coefficients = PolynomialCoeffs { coeffs };

         let points = fft(&coefficients);
         assert_eq!(points, evaluate_naive(&coefficients));
@@ -263,7 +266,7 @@ mod tests {
         let values = subgroup
             .into_iter()
-            .map(|x| evaluate_at_naive(&coefficients, x))
+            .map(|x| evaluate_at_naive(coefficients, x))
             .collect();

         PolynomialValues::new(values)
     }
@@ -272,8 +275,8 @@ mod tests {
         let mut sum = F::ZERO;
         let mut point_power = F::ONE;
         for &c in &coefficients.coeffs {
-            sum = sum + c * point_power;
-            point_power = point_power * point;
+            sum += c * point_power;
+            point_power *= point;
         }
         sum
     }

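Aside: the rewritten `fft_and_ifft` test pads the coefficient vector up to `degree.next_power_of_two()`, since the radix-2 FFT needs a power-of-two length (the old code misused `log2_ceil` for this). A standalone sketch of just the padding step, with plain integers standing in for field elements:

fn pad_to_pow2(mut coeffs: Vec<i64>) -> Vec<i64> {
    let padded_len = coeffs.len().next_power_of_two();
    coeffs.resize(padded_len, 0); // append zeros up to the next power of two
    coeffs
}

fn main() {
    let coeffs = pad_to_pow2((0..200).map(|i| i * 1337 % 100).collect());
    assert_eq!(coeffs.len(), 256);
    assert!(coeffs[200..].iter().all(|&c| c == 0));
}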
View File

@@ -13,14 +13,17 @@ macro_rules! test_field_arithmetic {
             #[test]
             fn batch_inversion() {
-                let xs = (1..=3)
-                    .map(|i| <$field>::from_canonical_u64(i))
-                    .collect::<Vec<_>>();
-                let invs = <$field>::batch_multiplicative_inverse(&xs);
-                for (x, inv) in xs.into_iter().zip(invs) {
-                    assert_eq!(x * inv, <$field>::ONE);
-                }
+                for n in 0..20 {
+                    let xs = (1..=n as u64)
+                        .map(|i| <$field>::from_canonical_u64(i))
+                        .collect::<Vec<_>>();
+                    let invs = <$field>::batch_multiplicative_inverse(&xs);
+                    assert_eq!(invs.len(), n);
+                    for (x, inv) in xs.into_iter().zip(invs) {
+                        assert_eq!(x * inv, <$field>::ONE);
+                    }
+                }
             }

             #[test]
             fn primitive_root_order() {
@@ -81,10 +84,24 @@ macro_rules! test_field_arithmetic {
                 assert_eq!(base.exp_biguint(&pow), base.exp_biguint(&big_pow));
                 assert_ne!(base.exp_biguint(&pow), base.exp_biguint(&big_pow_wrong));
             }
+
+            #[test]
+            fn inverses() {
+                type F = $field;
+
+                let x = F::rand();
+                let x1 = x.inverse();
+                let x2 = x1.inverse();
+                let x3 = x2.inverse();
+
+                assert_eq!(x, x2);
+                assert_eq!(x1, x3);
+            }
         }
     };
 }

+#[allow(clippy::eq_op)]
 pub(crate) fn test_add_neg_sub_mul<BF: Extendable<D>, const D: usize>() {
     let x = BF::Extension::rand();
     let y = BF::Extension::rand();

View File

@ -1,11 +1,10 @@
use std::convert::TryInto;
use std::fmt::{Debug, Display}; use std::fmt::{Debug, Display};
use std::hash::Hash; use std::hash::Hash;
use std::iter::{Product, Sum}; use std::iter::{Product, Sum};
use std::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssign}; use std::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssign};
use num::bigint::BigUint; use num::bigint::BigUint;
use num::{Integer, One, Zero}; use num::{Integer, One, ToPrimitive, Zero};
use rand::Rng; use rand::Rng;
use serde::de::DeserializeOwned; use serde::de::DeserializeOwned;
use serde::Serialize; use serde::Serialize;
@ -43,24 +42,28 @@ pub trait Field:
+ Serialize + Serialize
+ DeserializeOwned + DeserializeOwned
{ {
type PrimeField: PrimeField;
const ZERO: Self; const ZERO: Self;
const ONE: Self; const ONE: Self;
const TWO: Self; const TWO: Self;
const NEG_ONE: Self; const NEG_ONE: Self;
const CHARACTERISTIC: u64;
/// The 2-adicity of this field's multiplicative group. /// The 2-adicity of this field's multiplicative group.
const TWO_ADICITY: usize; const TWO_ADICITY: usize;
/// The field's characteristic and it's 2-adicity.
/// Set to `None` when the characteristic doesn't fit in a u64.
const CHARACTERISTIC_TWO_ADICITY: usize;
/// Generator of the entire multiplicative group, i.e. all non-zero elements. /// Generator of the entire multiplicative group, i.e. all non-zero elements.
const MULTIPLICATIVE_GROUP_GENERATOR: Self; const MULTIPLICATIVE_GROUP_GENERATOR: Self;
/// Generator of a multiplicative subgroup of order `2^TWO_ADICITY`. /// Generator of a multiplicative subgroup of order `2^TWO_ADICITY`.
const POWER_OF_TWO_GENERATOR: Self; const POWER_OF_TWO_GENERATOR: Self;
/// The bit length of the field order.
const BITS: usize;
fn order() -> BigUint; fn order() -> BigUint;
fn characteristic() -> BigUint;
#[inline] #[inline]
fn is_zero(&self) -> bool { fn is_zero(&self) -> bool {
@ -92,6 +95,10 @@ pub trait Field:
self.square() * *self self.square() * *self
} }
fn triple(&self) -> Self {
*self * (Self::ONE + Self::TWO)
}
/// Compute the multiplicative inverse of this field element. /// Compute the multiplicative inverse of this field element.
fn try_inverse(&self) -> Option<Self>; fn try_inverse(&self) -> Option<Self>;
@ -103,34 +110,91 @@ pub trait Field:
// This is Montgomery's trick. At a high level, we invert the product of the given field // This is Montgomery's trick. At a high level, we invert the product of the given field
// elements, then derive the individual inverses from that via multiplication. // elements, then derive the individual inverses from that via multiplication.
// The usual Montgomery trick involves calculating an array of cumulative products,
// resulting in a long dependency chain. To increase instruction-level parallelism, we
// compute WIDTH separate cumulative product arrays that only meet at the end.
// Higher WIDTH increases instruction-level parallelism, but too high a value will cause us
// to run out of registers.
const WIDTH: usize = 4;
// JN note: WIDTH is 4. The code is specialized to this value and will need
// modification if it is changed. I tried to make it more generic, but Rust's const
// generics are not yet good enough.
// Handle special cases. Paradoxically, below is repetitive but concise.
// The branches should be very predictable.
let n = x.len(); let n = x.len();
if n == 0 { if n == 0 {
return Vec::new(); return Vec::new();
} } else if n == 1 {
if n == 1 {
return vec![x[0].inverse()]; return vec![x[0].inverse()];
} else if n == 2 {
let x01 = x[0] * x[1];
let x01inv = x01.inverse();
return vec![x01inv * x[1], x01inv * x[0]];
} else if n == 3 {
let x01 = x[0] * x[1];
let x012 = x01 * x[2];
let x012inv = x012.inverse();
let x01inv = x012inv * x[2];
return vec![x01inv * x[1], x01inv * x[0], x012inv * x01];
} }
debug_assert!(n >= WIDTH);
// Fill buf with cumulative product of x. // Buf is reused for a few things to save allocations.
let mut buf = Vec::with_capacity(n); // Fill buf with cumulative product of x, only taking every 4th value. Concretely, buf will
let mut cumul_prod = x[0]; // be [
buf.push(cumul_prod); // x[0], x[1], x[2], x[3],
for i in 1..n { // x[0] * x[4], x[1] * x[5], x[2] * x[6], x[3] * x[7],
cumul_prod *= x[i]; // x[0] * x[4] * x[8], x[1] * x[5] * x[9], x[2] * x[6] * x[10], x[3] * x[7] * x[11],
buf.push(cumul_prod); // ...
// ].
// If n is not a multiple of WIDTH, the result is truncated from the end. For example,
// for n == 5, we get [x[0], x[1], x[2], x[3], x[0] * x[4]].
let mut buf: Vec<Self> = Vec::with_capacity(n);
// cumul_prod holds the last WIDTH elements of buf. This is redundant, but it's how we
// convince LLVM to keep the values in the registers.
let mut cumul_prod: [Self; WIDTH] = x[..WIDTH].try_into().unwrap();
buf.extend(cumul_prod);
for (i, &xi) in x[WIDTH..].iter().enumerate() {
cumul_prod[i % WIDTH] *= xi;
buf.push(cumul_prod[i % WIDTH]);
} }
debug_assert_eq!(buf.len(), n);
// At this stage buf contains the the cumulative product of x. We reuse the buffer for let mut a_inv = {
// efficiency. At the end of the loop, it is filled with inverses of x. // This is where the four dependency chains meet.
let mut a_inv = cumul_prod.inverse(); // Take the last four elements of buf and invert them all.
buf[n - 1] = buf[n - 2] * a_inv; let c01 = cumul_prod[0] * cumul_prod[1];
for i in (1..n - 1).rev() { let c23 = cumul_prod[2] * cumul_prod[3];
a_inv = x[i + 1] * a_inv; let c0123 = c01 * c23;
// buf[i - 1] has not been written to by this loop, so it equals x[0] * ... x[n - 1]. let c0123inv = c0123.inverse();
buf[i] = buf[i - 1] * a_inv; let c01inv = c0123inv * c23;
let c23inv = c0123inv * c01;
[
c01inv * cumul_prod[1],
c01inv * cumul_prod[0],
c23inv * cumul_prod[3],
c23inv * cumul_prod[2],
]
};
for i in (WIDTH..n).rev() {
// buf[i - WIDTH] has not been written to by this loop, so it equals
// x[i % WIDTH] * x[i % WIDTH + WIDTH] * ... * x[i - WIDTH].
buf[i] = buf[i - WIDTH] * a_inv[i % WIDTH];
// buf[i] now holds the inverse of x[i]. // buf[i] now holds the inverse of x[i].
a_inv[i % WIDTH] *= x[i];
} }
buf[0] = x[1] * a_inv; for i in (0..WIDTH).rev() {
buf[i] = a_inv[i];
}
for (&bi, &xi) in buf.iter().zip(x) {
// Sanity check only.
debug_assert_eq!(bi * xi, Self::ONE);
}
buf buf
} }
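For reference, the single-chain form of Montgomery's trick that the WIDTH = 4 code above unrolls might look like the sketch below. It uses a toy Mersenne-prime field rather than the crate's `Field` trait, and assumes every input is nonzero; one inversion serves the whole batch.

```rust
const P: u64 = (1 << 31) - 1; // toy modulus: the Mersenne prime 2^31 - 1

// Operands are assumed reduced mod P, so a * b < 2^62 never overflows u64.
fn mul(a: u64, b: u64) -> u64 {
    a * b % P
}

// Fermat inversion: a^(p-2) mod p, via square-and-multiply.
fn inv(a: u64) -> u64 {
    let (mut base, mut e, mut acc) = (a, P - 2, 1);
    while e > 0 {
        if e & 1 == 1 {
            acc = mul(acc, base);
        }
        base = mul(base, base);
        e >>= 1;
    }
    acc
}

fn batch_inverse(x: &[u64]) -> Vec<u64> {
    // Forward pass: buf[i] = x[0] * ... * x[i].
    let mut buf = Vec::with_capacity(x.len());
    let mut cumul = 1;
    for &xi in x {
        cumul = mul(cumul, xi);
        buf.push(cumul);
    }
    // The only inversion in the whole routine.
    let mut a_inv = inv(cumul);
    // Backward pass: peel one factor off at a time.
    for i in (1..x.len()).rev() {
        buf[i] = mul(buf[i - 1], a_inv); // = 1 / x[i]
        a_inv = mul(a_inv, x[i]); // now 1 / (x[0] * ... * x[i-1])
    }
    if !x.is_empty() {
        buf[0] = a_inv;
    }
    buf
}
```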
@ -142,30 +206,32 @@ pub trait Field:
// exp exceeds t, we repeatedly multiply by 2^-t and reduce // exp exceeds t, we repeatedly multiply by 2^-t and reduce
// exp until it's in the right range. // exp until it's in the right range.
let p = Self::CHARACTERISTIC; if let Some(p) = Self::characteristic().to_u64() {
// NB: The only reason this is split into two cases is to save // NB: The only reason this is split into two cases is to save
// the multiplication (and possible calculation of // the multiplication (and possible calculation of
// inverse_2_pow_adicity) in the usual case that exp <= // inverse_2_pow_adicity) in the usual case that exp <=
// TWO_ADICITY. Can remove the branch and simplify if that // TWO_ADICITY. Can remove the branch and simplify if that
// saving isn't worth it. // saving isn't worth it.
if exp > Self::PrimeField::TWO_ADICITY { if exp > Self::CHARACTERISTIC_TWO_ADICITY {
// NB: This should be a compile-time constant // NB: This should be a compile-time constant
let inverse_2_pow_adicity: Self = let inverse_2_pow_adicity: Self =
Self::from_canonical_u64(p - ((p - 1) >> Self::PrimeField::TWO_ADICITY)); Self::from_canonical_u64(p - ((p - 1) >> Self::CHARACTERISTIC_TWO_ADICITY));
let mut res = inverse_2_pow_adicity; let mut res = inverse_2_pow_adicity;
let mut e = exp - Self::PrimeField::TWO_ADICITY; let mut e = exp - Self::CHARACTERISTIC_TWO_ADICITY;
while e > Self::PrimeField::TWO_ADICITY { while e > Self::CHARACTERISTIC_TWO_ADICITY {
res *= inverse_2_pow_adicity; res *= inverse_2_pow_adicity;
e -= Self::PrimeField::TWO_ADICITY; e -= Self::CHARACTERISTIC_TWO_ADICITY;
} }
res * Self::from_canonical_u64(p - ((p - 1) >> e)) res * Self::from_canonical_u64(p - ((p - 1) >> e))
} else { } else {
Self::from_canonical_u64(p - ((p - 1) >> exp)) Self::from_canonical_u64(p - ((p - 1) >> exp))
} }
} else {
Self::TWO.inverse().exp_u64(exp as u64)
}
} }
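The identity behind `p - ((p - 1) >> exp)` is that 2^exp * (p - (p - 1)/2^exp) = 2^exp * p - (p - 1) ≡ 1 (mod p), valid whenever 2^exp divides p - 1, i.e. exp is at most the two-adicity. A quick standalone check over the Goldilocks modulus (a sketch with u128 arithmetic, not the `Field` trait):

```rust
const P: u128 = 0xFFFF_FFFF_0000_0001; // Goldilocks: 2^64 - 2^32 + 1, two-adicity 32

fn main() {
    for exp in 0..=32u32 {
        let y = P - ((P - 1) >> exp); // candidate for 2^-exp
        let two_exp = 1u128 << exp;
        assert_eq!(two_exp * y % P, 1); // 2^exp * y == 1 (mod p)
    }
}
```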
fn primitive_root_of_unity(n_log: usize) -> Self { fn primitive_root_of_unity(n_log: usize) -> Self {
@ -206,6 +272,11 @@ pub trait Field:
subgroup.into_iter().map(|x| x * shift).collect() subgroup.into_iter().map(|x| x * shift).collect()
} }
// TODO: move these to a new `PrimeField` trait (for all prime fields, not just 64-bit ones)
fn from_biguint(n: BigUint) -> Self;
fn to_biguint(&self) -> BigUint;
fn from_canonical_u64(n: u64) -> Self; fn from_canonical_u64(n: u64) -> Self;
fn from_canonical_u32(n: u32) -> Self { fn from_canonical_u32(n: u32) -> Self {
@ -274,7 +345,7 @@ pub trait Field:
} }
fn kth_root_u64(&self, k: u64) -> Self { fn kth_root_u64(&self, k: u64) -> Self {
let p = Self::order().clone(); let p = Self::order();
let p_minus_1 = &p - 1u32; let p_minus_1 = &p - 1u32;
debug_assert!( debug_assert!(
Self::is_monomial_permutation_u64(k), Self::is_monomial_permutation_u64(k),
@ -356,6 +427,7 @@ pub trait PrimeField: Field {
unsafe { self.sub_canonical_u64(1) } unsafe { self.sub_canonical_u64(1) }
} }
/// # Safety
/// Equivalent to *self + Self::from_canonical_u64(rhs), but may be cheaper. The caller must /// Equivalent to *self + Self::from_canonical_u64(rhs), but may be cheaper. The caller must
/// ensure that 0 <= rhs < Self::ORDER. The function may return incorrect results if this /// ensure that 0 <= rhs < Self::ORDER. The function may return incorrect results if this
/// precondition is not met. It is marked unsafe for this reason. /// precondition is not met. It is marked unsafe for this reason.
@ -365,6 +437,7 @@ pub trait PrimeField: Field {
*self + Self::from_canonical_u64(rhs) *self + Self::from_canonical_u64(rhs)
} }
/// # Safety
/// Equivalent to *self - Self::from_canonical_u64(rhs), but may be cheaper. The caller must /// Equivalent to *self - Self::from_canonical_u64(rhs), but may be cheaper. The caller must
/// ensure that 0 <= rhs < Self::ORDER. The function may return incorrect results if this /// ensure that 0 <= rhs < Self::ORDER. The function may return incorrect results if this
/// precondition is not met. It is marked unsafe for this reason. /// precondition is not met. It is marked unsafe for this reason.


@ -4,7 +4,7 @@ use std::hash::{Hash, Hasher};
use std::iter::{Product, Sum}; use std::iter::{Product, Sum};
use std::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssign}; use std::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssign};
use num::BigUint; use num::{BigUint, Integer};
use rand::Rng; use rand::Rng;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
@ -62,15 +62,13 @@ impl Debug for GoldilocksField {
} }
impl Field for GoldilocksField { impl Field for GoldilocksField {
type PrimeField = Self;
const ZERO: Self = Self(0); const ZERO: Self = Self(0);
const ONE: Self = Self(1); const ONE: Self = Self(1);
const TWO: Self = Self(2); const TWO: Self = Self(2);
const NEG_ONE: Self = Self(Self::ORDER - 1); const NEG_ONE: Self = Self(Self::ORDER - 1);
const CHARACTERISTIC: u64 = Self::ORDER;
const TWO_ADICITY: usize = 32; const TWO_ADICITY: usize = 32;
const CHARACTERISTIC_TWO_ADICITY: usize = Self::TWO_ADICITY;
// Sage: `g = GF(p).multiplicative_generator()` // Sage: `g = GF(p).multiplicative_generator()`
const MULTIPLICATIVE_GROUP_GENERATOR: Self = Self(7); const MULTIPLICATIVE_GROUP_GENERATOR: Self = Self(7);
@ -82,15 +80,28 @@ impl Field for GoldilocksField {
// ``` // ```
const POWER_OF_TWO_GENERATOR: Self = Self(1753635133440165772); const POWER_OF_TWO_GENERATOR: Self = Self(1753635133440165772);
const BITS: usize = 64;
fn order() -> BigUint { fn order() -> BigUint {
Self::ORDER.into() Self::ORDER.into()
} }
fn characteristic() -> BigUint {
Self::order()
}
#[inline(always)] #[inline(always)]
fn try_inverse(&self) -> Option<Self> { fn try_inverse(&self) -> Option<Self> {
try_inverse_u64(self) try_inverse_u64(self)
} }
fn from_biguint(n: BigUint) -> Self {
// `to_u64_digits` returns an empty vec for zero, so avoid indexing into it.
Self(n.mod_floor(&Self::order()).to_u64_digits().first().copied().unwrap_or(0))
}
fn to_biguint(&self) -> BigUint {
self.to_canonical_u64().into()
}
#[inline] #[inline]
fn from_canonical_u64(n: u64) -> Self { fn from_canonical_u64(n: u64) -> Self {
debug_assert!(n < Self::ORDER); debug_assert!(n < Self::ORDER);
@ -312,6 +323,7 @@ impl RichField for GoldilocksField {}
#[inline(always)] #[inline(always)]
#[cfg(target_arch = "x86_64")] #[cfg(target_arch = "x86_64")]
unsafe fn add_no_canonicalize_trashing_input(x: u64, y: u64) -> u64 { unsafe fn add_no_canonicalize_trashing_input(x: u64, y: u64) -> u64 {
use std::arch::asm;
let res_wrapped: u64; let res_wrapped: u64;
let adjustment: u64; let adjustment: u64;
asm!( asm!(
@ -352,6 +364,7 @@ unsafe fn add_no_canonicalize_trashing_input(x: u64, y: u64) -> u64 {
#[inline(always)] #[inline(always)]
#[cfg(target_arch = "x86_64")] #[cfg(target_arch = "x86_64")]
unsafe fn sub_no_canonicalize_trashing_input(x: u64, y: u64) -> u64 { unsafe fn sub_no_canonicalize_trashing_input(x: u64, y: u64) -> u64 {
use std::arch::asm;
let res_wrapped: u64; let res_wrapped: u64;
let adjustment: u64; let adjustment: u64;
asm!( asm!(


@ -1,6 +1,6 @@
use crate::field::fft::ifft; use crate::field::fft::ifft;
use crate::field::field_types::Field; use crate::field::field_types::Field;
use crate::polynomial::polynomial::{PolynomialCoeffs, PolynomialValues}; use crate::polynomial::{PolynomialCoeffs, PolynomialValues};
use crate::util::log2_ceil; use crate::util::log2_ceil;
/// Computes the unique degree < n interpolant of an arbitrary list of n (point, value) pairs. /// Computes the unique degree < n interpolant of an arbitrary list of n (point, value) pairs.
@ -80,7 +80,7 @@ mod tests {
use crate::field::extension_field::quartic::QuarticExtension; use crate::field::extension_field::quartic::QuarticExtension;
use crate::field::field_types::Field; use crate::field::field_types::Field;
use crate::field::goldilocks_field::GoldilocksField; use crate::field::goldilocks_field::GoldilocksField;
use crate::polynomial::polynomial::PolynomialCoeffs; use crate::polynomial::PolynomialCoeffs;
#[test] #[test]
fn interpolant_random() { fn interpolant_random() {


@ -7,7 +7,8 @@ pub(crate) mod interpolation;
mod inversion; mod inversion;
pub(crate) mod packable; pub(crate) mod packable;
pub(crate) mod packed_field; pub(crate) mod packed_field;
pub mod secp256k1; pub mod secp256k1_base;
pub mod secp256k1_scalar;
#[cfg(target_feature = "avx2")] #[cfg(target_feature = "avx2")]
pub(crate) mod packed_avx2; pub(crate) mod packed_avx2;


@ -1,18 +1,18 @@
use crate::field::field_types::Field; use crate::field::field_types::Field;
use crate::field::packed_field::{PackedField, Singleton}; use crate::field::packed_field::PackedField;
/// Points us to the default packing for a particular field. There may be multiple choices of /// Points us to the default packing for a particular field. There may be multiple choices of
/// PackedField for a particular Field (e.g. Singleton works for all fields), but this is the /// PackedField for a particular Field (e.g. every Field is also a PackedField), but this is the
/// recommended one. The recommended packing varies by target_arch and target_feature. /// recommended one. The recommended packing varies by target_arch and target_feature.
pub trait Packable: Field { pub trait Packable: Field {
type PackedType: PackedField<FieldType = Self>; type Packing: PackedField<Scalar = Self>;
} }
impl<F: Field> Packable for F { impl<F: Field> Packable for F {
default type PackedType = Singleton<Self>; default type Packing = Self;
} }
#[cfg(target_feature = "avx2")] #[cfg(target_feature = "avx2")]
impl Packable for crate::field::goldilocks_field::GoldilocksField { impl Packable for crate::field::goldilocks_field::GoldilocksField {
type PackedType = crate::field::packed_avx2::PackedGoldilocksAVX2; type Packing = crate::field::packed_avx2::PackedGoldilocksAvx2;
} }
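Hypothetical in-crate usage of `Packable` (the function name and the divisibility assumption are ours): generic code names `<F as Packable>::Packing` once and gets the AVX2 lanes on targets that have them, or the width-1 scalar fallback elsewhere.

```rust
use crate::field::packable::Packable;
use crate::field::packed_field::PackedField;

/// Multiply every element of `values` by `c`, one packed vector at a time.
/// A sketch: assumes `values.len()` is a multiple of the packing width.
fn scale_in_place<F: Packable>(values: &mut [F], c: F) {
    for chunk in <F as Packable>::Packing::pack_slice_mut(values) {
        *chunk *= c;
    }
}
```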


@ -2,20 +2,20 @@ use core::arch::x86_64::*;
use std::fmt; use std::fmt;
use std::fmt::{Debug, Formatter}; use std::fmt::{Debug, Formatter};
use std::iter::{Product, Sum}; use std::iter::{Product, Sum};
use std::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign}; use std::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssign};
use crate::field::field_types::PrimeField; use crate::field::field_types::PrimeField;
use crate::field::packed_avx2::common::{ use crate::field::packed_avx2::common::{
add_no_canonicalize_64_64s_s, epsilon, field_order, ReducibleAVX2, add_no_canonicalize_64_64s_s, epsilon, field_order, shift, ReducibleAvx2,
}; };
use crate::field::packed_field::PackedField; use crate::field::packed_field::PackedField;
// PackedPrimeField wraps an array of four u64s, with the new and get methods to convert that // Avx2PrimeField wraps an array of four u64s, with the new and get methods to convert that
// array to and from __m256i, which is the type we actually operate on. This indirection is a // array to and from __m256i, which is the type we actually operate on. This indirection is a
// terrible trick to change PackedPrimeField's alignment. // terrible trick to change Avx2PrimeField's alignment.
// We'd like to be able to cast slices of PrimeField to slices of PackedPrimeField. Rust // We'd like to be able to cast slices of PrimeField to slices of Avx2PrimeField. Rust
// aligns __m256i to 32 bytes but PrimeField has a lower alignment. That alignment extends to // aligns __m256i to 32 bytes but PrimeField has a lower alignment. That alignment extends to
// PackedPrimeField and it appears that it cannot be lowered with #[repr(C, blah)]. It is // Avx2PrimeField and it appears that it cannot be lowered with #[repr(C, blah)]. It is
// important for Rust not to assume 32-byte alignment, so we cannot wrap __m256i directly. // important for Rust not to assume 32-byte alignment, so we cannot wrap __m256i directly.
// There are two versions of vectorized load/store instructions on x86: aligned (vmovaps and // There are two versions of vectorized load/store instructions on x86: aligned (vmovaps and
// friends) and unaligned (vmovups etc.). The difference between them is that aligned loads and // friends) and unaligned (vmovups etc.). The difference between them is that aligned loads and
@ -23,12 +23,12 @@ use crate::field::packed_field::PackedField;
// were faster, and although this is no longer the case, compilers prefer the aligned versions if // were faster, and although this is no longer the case, compilers prefer the aligned versions if
// they know that the address is aligned. Using aligned instructions on unaligned addresses leads to // they know that the address is aligned. Using aligned instructions on unaligned addresses leads to
// bugs that can be frustrating to diagnose. Hence, we can't have Rust assuming alignment, and // bugs that can be frustrating to diagnose. Hence, we can't have Rust assuming alignment, and
// therefore PackedPrimeField wraps [F; 4] and not __m256i. // therefore Avx2PrimeField wraps [F; 4] and not __m256i.
#[derive(Copy, Clone)] #[derive(Copy, Clone)]
#[repr(transparent)] #[repr(transparent)]
pub struct PackedPrimeField<F: ReducibleAVX2>(pub [F; 4]); pub struct Avx2PrimeField<F: ReducibleAvx2>(pub [F; 4]);
impl<F: ReducibleAVX2> PackedPrimeField<F> { impl<F: ReducibleAvx2> Avx2PrimeField<F> {
#[inline] #[inline]
fn new(x: __m256i) -> Self { fn new(x: __m256i) -> Self {
let mut obj = Self([F::ZERO; 4]); let mut obj = Self([F::ZERO; 4]);
@ -43,84 +43,111 @@ impl<F: ReducibleAVX2> PackedPrimeField<F> {
let ptr = (&self.0).as_ptr().cast::<__m256i>(); let ptr = (&self.0).as_ptr().cast::<__m256i>();
unsafe { _mm256_loadu_si256(ptr) } unsafe { _mm256_loadu_si256(ptr) }
} }
/// Addition that assumes x + y < 2^64 + F::ORDER. May return incorrect results if this
/// condition is not met, hence it is marked unsafe.
#[inline]
pub unsafe fn add_canonical_u64(&self, rhs: __m256i) -> Self {
Self::new(add_canonical_u64::<F>(self.get(), rhs))
}
} }
impl<F: ReducibleAVX2> Add<Self> for PackedPrimeField<F> { impl<F: ReducibleAvx2> Add<Self> for Avx2PrimeField<F> {
type Output = Self; type Output = Self;
#[inline] #[inline]
fn add(self, rhs: Self) -> Self { fn add(self, rhs: Self) -> Self {
Self::new(unsafe { add::<F>(self.get(), rhs.get()) }) Self::new(unsafe { add::<F>(self.get(), rhs.get()) })
} }
} }
impl<F: ReducibleAVX2> Add<F> for PackedPrimeField<F> { impl<F: ReducibleAvx2> Add<F> for Avx2PrimeField<F> {
type Output = Self; type Output = Self;
#[inline] #[inline]
fn add(self, rhs: F) -> Self { fn add(self, rhs: F) -> Self {
self + Self::broadcast(rhs) self + Self::from(rhs)
} }
} }
impl<F: ReducibleAVX2> AddAssign<Self> for PackedPrimeField<F> { impl<F: ReducibleAvx2> Add<Avx2PrimeField<F>> for <Avx2PrimeField<F> as PackedField>::Scalar {
type Output = Avx2PrimeField<F>;
#[inline]
fn add(self, rhs: Self::Output) -> Self::Output {
Self::Output::from(self) + rhs
}
}
impl<F: ReducibleAvx2> AddAssign<Self> for Avx2PrimeField<F> {
#[inline] #[inline]
fn add_assign(&mut self, rhs: Self) { fn add_assign(&mut self, rhs: Self) {
*self = *self + rhs; *self = *self + rhs;
} }
} }
impl<F: ReducibleAVX2> AddAssign<F> for PackedPrimeField<F> { impl<F: ReducibleAvx2> AddAssign<F> for Avx2PrimeField<F> {
#[inline] #[inline]
fn add_assign(&mut self, rhs: F) { fn add_assign(&mut self, rhs: F) {
*self = *self + rhs; *self = *self + rhs;
} }
} }
impl<F: ReducibleAVX2> Debug for PackedPrimeField<F> { impl<F: ReducibleAvx2> Debug for Avx2PrimeField<F> {
#[inline] #[inline]
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
write!(f, "({:?})", self.get()) write!(f, "({:?})", self.get())
} }
} }
impl<F: ReducibleAVX2> Default for PackedPrimeField<F> { impl<F: ReducibleAvx2> Default for Avx2PrimeField<F> {
#[inline] #[inline]
fn default() -> Self { fn default() -> Self {
Self::zero() Self::ZERO
} }
} }
impl<F: ReducibleAVX2> Mul<Self> for PackedPrimeField<F> { impl<F: ReducibleAvx2> Div<F> for Avx2PrimeField<F> {
type Output = Self;
#[inline]
fn div(self, rhs: F) -> Self {
self * rhs.inverse()
}
}
impl<F: ReducibleAvx2> DivAssign<F> for Avx2PrimeField<F> {
#[inline]
fn div_assign(&mut self, rhs: F) {
*self *= rhs.inverse();
}
}
impl<F: ReducibleAvx2> From<F> for Avx2PrimeField<F> {
fn from(x: F) -> Self {
Self([x; 4])
}
}
impl<F: ReducibleAvx2> Mul<Self> for Avx2PrimeField<F> {
type Output = Self; type Output = Self;
#[inline] #[inline]
fn mul(self, rhs: Self) -> Self { fn mul(self, rhs: Self) -> Self {
Self::new(unsafe { mul::<F>(self.get(), rhs.get()) }) Self::new(unsafe { mul::<F>(self.get(), rhs.get()) })
} }
} }
impl<F: ReducibleAVX2> Mul<F> for PackedPrimeField<F> { impl<F: ReducibleAvx2> Mul<F> for Avx2PrimeField<F> {
type Output = Self; type Output = Self;
#[inline] #[inline]
fn mul(self, rhs: F) -> Self { fn mul(self, rhs: F) -> Self {
self * Self::broadcast(rhs) self * Self::from(rhs)
} }
} }
impl<F: ReducibleAVX2> MulAssign<Self> for PackedPrimeField<F> { impl<F: ReducibleAvx2> Mul<Avx2PrimeField<F>> for <Avx2PrimeField<F> as PackedField>::Scalar {
type Output = Avx2PrimeField<F>;
#[inline]
fn mul(self, rhs: Avx2PrimeField<F>) -> Self::Output {
Self::Output::from(self) * rhs
}
}
impl<F: ReducibleAvx2> MulAssign<Self> for Avx2PrimeField<F> {
#[inline] #[inline]
fn mul_assign(&mut self, rhs: Self) { fn mul_assign(&mut self, rhs: Self) {
*self = *self * rhs; *self = *self * rhs;
} }
} }
impl<F: ReducibleAVX2> MulAssign<F> for PackedPrimeField<F> { impl<F: ReducibleAvx2> MulAssign<F> for Avx2PrimeField<F> {
#[inline] #[inline]
fn mul_assign(&mut self, rhs: F) { fn mul_assign(&mut self, rhs: F) {
*self = *self * rhs; *self = *self * rhs;
} }
} }
impl<F: ReducibleAVX2> Neg for PackedPrimeField<F> { impl<F: ReducibleAvx2> Neg for Avx2PrimeField<F> {
type Output = Self; type Output = Self;
#[inline] #[inline]
fn neg(self) -> Self { fn neg(self) -> Self {
@ -128,52 +155,59 @@ impl<F: ReducibleAVX2> Neg for PackedPrimeField<F> {
} }
} }
impl<F: ReducibleAVX2> Product for PackedPrimeField<F> { impl<F: ReducibleAvx2> Product for Avx2PrimeField<F> {
#[inline] #[inline]
fn product<I: Iterator<Item = Self>>(iter: I) -> Self { fn product<I: Iterator<Item = Self>>(iter: I) -> Self {
iter.reduce(|x, y| x * y).unwrap_or(Self::one()) iter.reduce(|x, y| x * y).unwrap_or(Self::ONE)
} }
} }
impl<F: ReducibleAVX2> PackedField for PackedPrimeField<F> { unsafe impl<F: ReducibleAvx2> PackedField for Avx2PrimeField<F> {
const LOG2_WIDTH: usize = 2; const WIDTH: usize = 4;
type FieldType = F; type Scalar = F;
type PackedPrimeField = Avx2PrimeField<F>;
const ZERO: Self = Self([F::ZERO; 4]);
const ONE: Self = Self([F::ONE; 4]);
#[inline] #[inline]
fn broadcast(x: F) -> Self { fn from_arr(arr: [Self::Scalar; Self::WIDTH]) -> Self {
Self([x; 4])
}
#[inline]
fn from_arr(arr: [F; Self::WIDTH]) -> Self {
Self(arr) Self(arr)
} }
#[inline] #[inline]
fn to_arr(&self) -> [F; Self::WIDTH] { fn as_arr(&self) -> [Self::Scalar; Self::WIDTH] {
self.0 self.0
} }
#[inline] #[inline]
fn from_slice(slice: &[F]) -> Self { fn from_slice(slice: &[Self::Scalar]) -> &Self {
assert!(slice.len() == 4); assert_eq!(slice.len(), Self::WIDTH);
Self([slice[0], slice[1], slice[2], slice[3]]) unsafe { &*slice.as_ptr().cast() }
}
#[inline]
fn from_slice_mut(slice: &mut [Self::Scalar]) -> &mut Self {
assert_eq!(slice.len(), Self::WIDTH);
unsafe { &mut *slice.as_mut_ptr().cast() }
}
#[inline]
fn as_slice(&self) -> &[Self::Scalar] {
&self.0[..]
}
#[inline]
fn as_slice_mut(&mut self) -> &mut [Self::Scalar] {
&mut self.0[..]
} }
#[inline] #[inline]
fn to_vec(&self) -> Vec<F> { fn interleave(&self, other: Self, block_len: usize) -> (Self, Self) {
self.0.into()
}
#[inline]
fn interleave(&self, other: Self, r: usize) -> (Self, Self) {
let (v0, v1) = (self.get(), other.get()); let (v0, v1) = (self.get(), other.get());
let (res0, res1) = match r { let (res0, res1) = match block_len {
0 => unsafe { interleave0(v0, v1) },
1 => unsafe { interleave1(v0, v1) }, 1 => unsafe { interleave1(v0, v1) },
2 => (v0, v1), 2 => unsafe { interleave2(v0, v1) },
_ => panic!("r cannot be more than LOG2_WIDTH"), 4 => (v0, v1),
_ => panic!("unsupported block_len"),
}; };
(Self::new(res0), Self::new(res1)) (Self::new(res0), Self::new(res1))
} }
@ -184,47 +218,47 @@ impl<F: ReducibleAVX2> PackedField for PackedPrimeField<F> {
} }
} }
impl<F: ReducibleAVX2> Sub<Self> for PackedPrimeField<F> { impl<F: ReducibleAvx2> Sub<Self> for Avx2PrimeField<F> {
type Output = Self; type Output = Self;
#[inline] #[inline]
fn sub(self, rhs: Self) -> Self { fn sub(self, rhs: Self) -> Self {
Self::new(unsafe { sub::<F>(self.get(), rhs.get()) }) Self::new(unsafe { sub::<F>(self.get(), rhs.get()) })
} }
} }
impl<F: ReducibleAVX2> Sub<F> for PackedPrimeField<F> { impl<F: ReducibleAvx2> Sub<F> for Avx2PrimeField<F> {
type Output = Self; type Output = Self;
#[inline] #[inline]
fn sub(self, rhs: F) -> Self { fn sub(self, rhs: F) -> Self {
self - Self::broadcast(rhs) self - Self::from(rhs)
} }
} }
impl<F: ReducibleAVX2> SubAssign<Self> for PackedPrimeField<F> { impl<F: ReducibleAvx2> Sub<Avx2PrimeField<F>> for <Avx2PrimeField<F> as PackedField>::Scalar {
type Output = Avx2PrimeField<F>;
#[inline]
fn sub(self, rhs: Avx2PrimeField<F>) -> Self::Output {
Self::Output::from(self) - rhs
}
}
impl<F: ReducibleAvx2> SubAssign<Self> for Avx2PrimeField<F> {
#[inline] #[inline]
fn sub_assign(&mut self, rhs: Self) { fn sub_assign(&mut self, rhs: Self) {
*self = *self - rhs; *self = *self - rhs;
} }
} }
impl<F: ReducibleAVX2> SubAssign<F> for PackedPrimeField<F> { impl<F: ReducibleAvx2> SubAssign<F> for Avx2PrimeField<F> {
#[inline] #[inline]
fn sub_assign(&mut self, rhs: F) { fn sub_assign(&mut self, rhs: F) {
*self = *self - rhs; *self = *self - rhs;
} }
} }
impl<F: ReducibleAVX2> Sum for PackedPrimeField<F> { impl<F: ReducibleAvx2> Sum for Avx2PrimeField<F> {
#[inline] #[inline]
fn sum<I: Iterator<Item = Self>>(iter: I) -> Self { fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
iter.reduce(|x, y| x + y).unwrap_or(Self::zero()) iter.reduce(|x, y| x + y).unwrap_or(Self::ZERO)
} }
} }
const SIGN_BIT: u64 = 1 << 63;
#[inline]
unsafe fn sign_bit() -> __m256i {
_mm256_set1_epi64x(SIGN_BIT as i64)
}
// Resources: // Resources:
// 1. Intel Intrinsics Guide for explanation of each intrinsic: // 1. Intel Intrinsics Guide for explanation of each intrinsic:
// https://software.intel.com/sites/landingpage/IntrinsicsGuide/ // https://software.intel.com/sites/landingpage/IntrinsicsGuide/
@ -274,12 +308,6 @@ unsafe fn sign_bit() -> __m256i {
// Notice that the above 3-value addition still only requires two calls to shift, just like our // Notice that the above 3-value addition still only requires two calls to shift, just like our
// 2-value addition. // 2-value addition.
/// Add 2^63 with overflow. Needed to emulate unsigned comparisons (see point 3. above).
#[inline]
unsafe fn shift(x: __m256i) -> __m256i {
_mm256_xor_si256(x, sign_bit())
}
/// Convert to canonical representation. /// Convert to canonical representation.
/// The argument is assumed to be shifted by 1 << 63 (i.e. x_s = x + 1<<63, where x is the field /// The argument is assumed to be shifted by 1 << 63 (i.e. x_s = x + 1<<63, where x is the field
/// value). The returned value is similarly shifted by 1 << 63 (i.e. we return y_s = y + (1<<63), /// value). The returned value is similarly shifted by 1 << 63 (i.e. we return y_s = y + (1<<63),
@ -293,14 +321,6 @@ unsafe fn canonicalize_s<F: PrimeField>(x_s: __m256i) -> __m256i {
_mm256_add_epi64(x_s, wrapback_amt) _mm256_add_epi64(x_s, wrapback_amt)
} }
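A scalar model of the shifted canonicalization may help: values are kept XOR-ed with 2^63 so that signed compares order them correctly, and anything in [ORDER, 2^64) wraps back down by adding 2^64 - ORDER. A sketch for the Goldilocks modulus (the helper name is ours; the vector code stays in the shifted domain throughout instead of unshifting):

```rust
fn canonicalize_s(x_s: u64) -> u64 {
    const SIGN_BIT: u64 = 1 << 63;
    const ORDER: u64 = 0xFFFF_FFFF_0000_0001; // Goldilocks p
    let x = x_s ^ SIGN_BIT; // unshift
    let canonical = if x >= ORDER {
        // x - ORDER == x + (2^64 - ORDER) mod 2^64.
        x.wrapping_add(u64::MAX - ORDER + 1)
    } else {
        x
    };
    canonical ^ SIGN_BIT // reshift
}
```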
/// Addition that assumes x + y < 2^64 + F::ORDER.
#[inline]
unsafe fn add_canonical_u64<F: PrimeField>(x: __m256i, y: __m256i) -> __m256i {
let y_s = shift(y);
let res_s = add_no_canonicalize_64_64s_s::<F>(x, y_s);
shift(res_s)
}
#[inline] #[inline]
unsafe fn add<F: PrimeField>(x: __m256i, y: __m256i) -> __m256i { unsafe fn add<F: PrimeField>(x: __m256i, y: __m256i) -> __m256i {
let y_s = shift(y); let y_s = shift(y);
@ -326,78 +346,94 @@ unsafe fn neg<F: PrimeField>(y: __m256i) -> __m256i {
_mm256_sub_epi64(shift(field_order::<F>()), canonicalize_s::<F>(y_s)) _mm256_sub_epi64(shift(field_order::<F>()), canonicalize_s::<F>(y_s))
} }
/// Full 64-bit by 64-bit multiplication. This emulated multiplication is 1.5x slower than the /// Full 64-bit by 64-bit multiplication. This emulated multiplication is 1.33x slower than the
/// scalar instruction, but may be worth it if we want our data to live in vector registers. /// scalar instruction, but may be worth it if we want our data to live in vector registers.
#[inline] #[inline]
unsafe fn mul64_64_s(x: __m256i, y: __m256i) -> (__m256i, __m256i) { unsafe fn mul64_64(x: __m256i, y: __m256i) -> (__m256i, __m256i) {
let x_hi = _mm256_srli_epi64(x, 32); // We want to move the high 32 bits to the low position. The multiplication instruction ignores
let y_hi = _mm256_srli_epi64(y, 32); // the high 32 bits, so it's ok to just duplicate them into the low position. This duplication can
// be done on port 5; bitshifts run on ports 0 and 1, competing with multiplication.
// This instruction is only provided for 32-bit floats, not integers. It's unclear why Intel makes the
// distinction; the casts are free and it guarantees that the exact bit pattern is preserved.
// Using a swizzle instruction of the wrong domain (float vs int) does not increase latency
// since Haswell.
let x_hi = _mm256_castps_si256(_mm256_movehdup_ps(_mm256_castsi256_ps(x)));
let y_hi = _mm256_castps_si256(_mm256_movehdup_ps(_mm256_castsi256_ps(y)));
// All four pairwise multiplications
let mul_ll = _mm256_mul_epu32(x, y); let mul_ll = _mm256_mul_epu32(x, y);
let mul_lh = _mm256_mul_epu32(x, y_hi); let mul_lh = _mm256_mul_epu32(x, y_hi);
let mul_hl = _mm256_mul_epu32(x_hi, y); let mul_hl = _mm256_mul_epu32(x_hi, y);
let mul_hh = _mm256_mul_epu32(x_hi, y_hi); let mul_hh = _mm256_mul_epu32(x_hi, y_hi);
let res_lo0_s = shift(mul_ll); // Bignum addition
let res_lo1_s = _mm256_add_epi32(res_lo0_s, _mm256_slli_epi64(mul_lh, 32)); // Extract high 32 bits of mul_ll and add to mul_hl. This cannot overflow.
let res_lo2_s = _mm256_add_epi32(res_lo1_s, _mm256_slli_epi64(mul_hl, 32)); let mul_ll_hi = _mm256_srli_epi64::<32>(mul_ll);
let t0 = _mm256_add_epi64(mul_hl, mul_ll_hi);
// Extract low 32 bits of t0 and add to mul_lh. Again, this cannot overflow.
// Also, extract high 32 bits of t0 and add to mul_hh.
let t0_lo = _mm256_and_si256(t0, _mm256_set1_epi64x(u32::MAX.into()));
let t0_hi = _mm256_srli_epi64::<32>(t0);
let t1 = _mm256_add_epi64(mul_lh, t0_lo);
let t2 = _mm256_add_epi64(mul_hh, t0_hi);
// Lastly, extract the high 32 bits of t1 and add to t2.
let t1_hi = _mm256_srli_epi64::<32>(t1);
let res_hi = _mm256_add_epi64(t2, t1_hi);
// cmpgt returns -1 on true and 0 on false. Hence, the carry values below are set to -1 on // Form res_lo by combining the low half of mul_ll with the low half of t1 (shifted into high
// overflow and must be subtracted, not added. // position).
let carry0 = _mm256_cmpgt_epi64(res_lo0_s, res_lo1_s); let t1_lo = _mm256_castps_si256(_mm256_moveldup_ps(_mm256_castsi256_ps(t1)));
let carry1 = _mm256_cmpgt_epi64(res_lo1_s, res_lo2_s); let res_lo = _mm256_blend_epi32::<0xaa>(mul_ll, t1_lo);
let res_hi0 = mul_hh; (res_hi, res_lo)
let res_hi1 = _mm256_add_epi64(res_hi0, _mm256_srli_epi64(mul_lh, 32));
let res_hi2 = _mm256_add_epi64(res_hi1, _mm256_srli_epi64(mul_hl, 32));
let res_hi3 = _mm256_sub_epi64(res_hi2, carry0);
let res_hi4 = _mm256_sub_epi64(res_hi3, carry1);
(res_hi4, res_lo2_s)
} }
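Per lane, the vectorized routine above computes exactly the schoolbook 64x64 product from four 32x32 partial products. A scalar model with the same variable names (a sketch; easy to check against `(x as u128) * (y as u128)`):

```rust
fn mul64_64(x: u64, y: u64) -> (u64, u64) {
    let (x_lo, x_hi) = (x & 0xFFFF_FFFF, x >> 32);
    let (y_lo, y_hi) = (y & 0xFFFF_FFFF, y >> 32);

    // All four pairwise products of the 32-bit halves.
    let mul_ll = x_lo * y_lo;
    let mul_lh = x_lo * y_hi;
    let mul_hl = x_hi * y_lo;
    let mul_hh = x_hi * y_hi;

    // Carry propagation. None of these additions can overflow, since
    // (2^32 - 1)^2 + 2 * (2^32 - 1) < 2^64.
    let t0 = mul_hl + (mul_ll >> 32);
    let t1 = mul_lh + (t0 & 0xFFFF_FFFF);
    let t2 = mul_hh + (t0 >> 32);

    let res_hi = t2 + (t1 >> 32);
    let res_lo = (mul_ll & 0xFFFF_FFFF) | (t1 << 32);
    (res_hi, res_lo)
}
```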
/// Full 64-bit squaring. This routine is 1.2x faster than the scalar instruction. /// Full 64-bit squaring. This routine is 1.2x faster than the scalar instruction.
#[inline] #[inline]
unsafe fn square64_s(x: __m256i) -> (__m256i, __m256i) { unsafe fn square64(x: __m256i) -> (__m256i, __m256i) {
let x_hi = _mm256_srli_epi64(x, 32); // Get high 32 bits of x. See comment in mul64_64.
let x_hi = _mm256_castps_si256(_mm256_movehdup_ps(_mm256_castsi256_ps(x)));
// All pairwise multiplications.
let mul_ll = _mm256_mul_epu32(x, x); let mul_ll = _mm256_mul_epu32(x, x);
let mul_lh = _mm256_mul_epu32(x, x_hi); let mul_lh = _mm256_mul_epu32(x, x_hi);
let mul_hh = _mm256_mul_epu32(x_hi, x_hi); let mul_hh = _mm256_mul_epu32(x_hi, x_hi);
let res_lo0_s = shift(mul_ll); // Bignum addition, but mul_lh is shifted by 33 bits (not 32).
let res_lo1_s = _mm256_add_epi32(res_lo0_s, _mm256_slli_epi64(mul_lh, 33)); let mul_ll_hi = _mm256_srli_epi64::<33>(mul_ll);
let t0 = _mm256_add_epi64(mul_lh, mul_ll_hi);
let t0_hi = _mm256_srli_epi64::<31>(t0);
let res_hi = _mm256_add_epi64(mul_hh, t0_hi);
// cmpgt returns -1 on true and 0 on false. Hence, the carry values below are set to -1 on // Form low result by adding the mul_ll and the low 31 bits of mul_lh (shifted to the high
// overflow and must be subtracted, not added. // position).
let carry = _mm256_cmpgt_epi64(res_lo0_s, res_lo1_s); let mul_lh_lo = _mm256_slli_epi64::<33>(mul_lh);
let res_lo = _mm256_add_epi64(mul_ll, mul_lh_lo);
let res_hi0 = mul_hh; (res_hi, res_lo)
let res_hi1 = _mm256_add_epi64(res_hi0, _mm256_srli_epi64(mul_lh, 31));
let res_hi2 = _mm256_sub_epi64(res_hi1, carry);
(res_hi2, res_lo1_s)
} }
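The squaring variant saves a partial product because the cross term appears twice: x^2 = mul_hh * 2^64 + mul_lh * 2^33 + mul_ll, so mul_lh is weighted by 2^33 rather than 2^32, which is where the 33-bit and 31-bit shifts come from. A scalar model (a sketch):

```rust
fn square64(x: u64) -> (u64, u64) {
    let (x_lo, x_hi) = (x & 0xFFFF_FFFF, x >> 32);
    let mul_ll = x_lo * x_lo;
    let mul_lh = x_lo * x_hi;
    let mul_hh = x_hi * x_hi;

    // x^2 = mul_hh * 2^64 + mul_lh * 2^33 + mul_ll.
    let t0 = mul_lh + (mul_ll >> 33);
    let res_hi = mul_hh + (t0 >> 31);
    let res_lo = mul_ll.wrapping_add(mul_lh << 33);
    (res_hi, res_lo)
}
```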
/// Multiply two integers modulo FIELD_ORDER. /// Multiply two integers modulo FIELD_ORDER.
#[inline] #[inline]
unsafe fn mul<F: ReducibleAVX2>(x: __m256i, y: __m256i) -> __m256i { unsafe fn mul<F: ReducibleAvx2>(x: __m256i, y: __m256i) -> __m256i {
shift(F::reduce128s_s(mul64_64_s(x, y))) F::reduce128(mul64_64(x, y))
} }
/// Square an integer modulo FIELD_ORDER. /// Square an integer modulo FIELD_ORDER.
#[inline] #[inline]
unsafe fn square<F: ReducibleAVX2>(x: __m256i) -> __m256i { unsafe fn square<F: ReducibleAvx2>(x: __m256i) -> __m256i {
shift(F::reduce128s_s(square64_s(x))) F::reduce128(square64(x))
} }
#[inline] #[inline]
unsafe fn interleave0(x: __m256i, y: __m256i) -> (__m256i, __m256i) { unsafe fn interleave1(x: __m256i, y: __m256i) -> (__m256i, __m256i) {
let a = _mm256_unpacklo_epi64(x, y); let a = _mm256_unpacklo_epi64(x, y);
let b = _mm256_unpackhi_epi64(x, y); let b = _mm256_unpackhi_epi64(x, y);
(a, b) (a, b)
} }
#[inline] #[inline]
unsafe fn interleave1(x: __m256i, y: __m256i) -> (__m256i, __m256i) { unsafe fn interleave2(x: __m256i, y: __m256i) -> (__m256i, __m256i) {
let y_lo = _mm256_castsi256_si128(y); // This has 0 cost. let y_lo = _mm256_castsi256_si128(y); // This has 0 cost.
// 1 places y_lo in the high half of x; 0 would place it in the lower half. // 1 places y_lo in the high half of x; 0 would place it in the lower half.


@ -2,8 +2,22 @@ use core::arch::x86_64::*;
use crate::field::field_types::PrimeField; use crate::field::field_types::PrimeField;
pub trait ReducibleAVX2: PrimeField { pub trait ReducibleAvx2: PrimeField {
unsafe fn reduce128s_s(x_s: (__m256i, __m256i)) -> __m256i; unsafe fn reduce128(x: (__m256i, __m256i)) -> __m256i;
}
const SIGN_BIT: u64 = 1 << 63;
#[inline]
unsafe fn sign_bit() -> __m256i {
_mm256_set1_epi64x(SIGN_BIT as i64)
}
/// Add 2^63 with overflow. Needed to emulate unsigned comparisons (see point 3. in
/// avx2_prime_field.rs).
#[inline]
pub unsafe fn shift(x: __m256i) -> __m256i {
_mm256_xor_si256(x, sign_bit())
} }
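The reason this works: XOR-ing with the sign bit is an order-preserving bijection from u64 to i64, so AVX2's signed 64-bit compare applied to shifted operands behaves like the unsigned compare the ISA doesn't provide. A scalar illustration:

```rust
fn unsigned_lt_via_signed(a: u64, b: u64) -> bool {
    const SIGN_BIT: u64 = 1 << 63;
    ((a ^ SIGN_BIT) as i64) < ((b ^ SIGN_BIT) as i64)
}

fn main() {
    assert!(unsigned_lt_via_signed(1, u64::MAX)); // 1 < 2^64 - 1
    assert!(!unsigned_lt_via_signed(u64::MAX, 1));
    assert_eq!(unsigned_lt_via_signed(3, 7), 3u64 < 7u64);
}
```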
#[inline] #[inline]


@ -2,19 +2,21 @@ use core::arch::x86_64::*;
use crate::field::goldilocks_field::GoldilocksField; use crate::field::goldilocks_field::GoldilocksField;
use crate::field::packed_avx2::common::{ use crate::field::packed_avx2::common::{
add_no_canonicalize_64_64s_s, epsilon, sub_no_canonicalize_64s_64_s, ReducibleAVX2, add_no_canonicalize_64_64s_s, epsilon, shift, sub_no_canonicalize_64s_64_s, ReducibleAvx2,
}; };
/// Reduce a u128 modulo FIELD_ORDER. The input is (u64, u64), pre-shifted by 2^63. The result is /// Reduce a u128 modulo FIELD_ORDER. The input is (u64, u64), pre-shifted by 2^63. The result is
/// similarly shifted. /// similarly shifted.
impl ReducibleAVX2 for GoldilocksField { impl ReducibleAvx2 for GoldilocksField {
#[inline] #[inline]
unsafe fn reduce128s_s(x_s: (__m256i, __m256i)) -> __m256i { unsafe fn reduce128(x: (__m256i, __m256i)) -> __m256i {
let (hi0, lo0_s) = x_s; let (hi0, lo0) = x;
let lo0_s = shift(lo0);
let hi_hi0 = _mm256_srli_epi64(hi0, 32); let hi_hi0 = _mm256_srli_epi64(hi0, 32);
let lo1_s = sub_no_canonicalize_64s_64_s::<GoldilocksField>(lo0_s, hi_hi0); let lo1_s = sub_no_canonicalize_64s_64_s::<GoldilocksField>(lo0_s, hi_hi0);
let t1 = _mm256_mul_epu32(hi0, epsilon::<GoldilocksField>()); let t1 = _mm256_mul_epu32(hi0, epsilon::<GoldilocksField>());
let lo2_s = add_no_canonicalize_64_64s_s::<GoldilocksField>(t1, lo1_s); let lo2_s = add_no_canonicalize_64_64s_s::<GoldilocksField>(t1, lo1_s);
lo2_s let lo2 = shift(lo2_s);
lo2
} }
} }
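The reduction relies on 2^64 ≡ 2^32 - 1 and 2^96 ≡ -1 (mod p) for p = 2^64 - 2^32 + 1, so hi * 2^64 + lo ≡ lo - hi_hi + hi_lo * (2^32 - 1). A scalar model of the same steps (a sketch; unlike the vector version, which leaves its result possibly non-canonical, this one canonicalizes at the end):

```rust
const P: u64 = 0xFFFF_FFFF_0000_0001;
const EPS: u64 = 0xFFFF_FFFF; // 2^64 mod p

fn reduce128(x: u128) -> u64 {
    let (hi, lo) = ((x >> 64) as u64, x as u64);
    let (hi_hi, hi_lo) = (hi >> 32, hi & EPS);

    // lo - hi_hi; on borrow, correct by eps (a wrapped 2^64 is +eps mod p).
    let (mut t, borrow) = lo.overflowing_sub(hi_hi);
    if borrow {
        t = t.wrapping_sub(EPS);
    }
    // t + hi_lo * eps; on carry, correct by eps again. The product fits u64.
    let (mut t, carry) = t.overflowing_add(hi_lo * EPS);
    if carry {
        t = t.wrapping_add(EPS);
    }
    // t is congruent to x mod p and below 2^64, so one subtraction suffices.
    if t >= P {
        t -= P;
    }
    t
}
```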


@ -1,21 +1,21 @@
mod avx2_prime_field;
mod common; mod common;
mod goldilocks; mod goldilocks;
mod packed_prime_field;
use packed_prime_field::PackedPrimeField; use avx2_prime_field::Avx2PrimeField;
use crate::field::goldilocks_field::GoldilocksField; use crate::field::goldilocks_field::GoldilocksField;
pub type PackedGoldilocksAVX2 = PackedPrimeField<GoldilocksField>; pub type PackedGoldilocksAvx2 = Avx2PrimeField<GoldilocksField>;
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use crate::field::goldilocks_field::GoldilocksField; use crate::field::goldilocks_field::GoldilocksField;
use crate::field::packed_avx2::common::ReducibleAVX2; use crate::field::packed_avx2::avx2_prime_field::Avx2PrimeField;
use crate::field::packed_avx2::packed_prime_field::PackedPrimeField; use crate::field::packed_avx2::common::ReducibleAvx2;
use crate::field::packed_field::PackedField; use crate::field::packed_field::PackedField;
fn test_vals_a<F: ReducibleAVX2>() -> [F; 4] { fn test_vals_a<F: ReducibleAvx2>() -> [F; 4] {
[ [
F::from_noncanonical_u64(14479013849828404771), F::from_noncanonical_u64(14479013849828404771),
F::from_noncanonical_u64(9087029921428221768), F::from_noncanonical_u64(9087029921428221768),
@ -23,7 +23,7 @@ mod tests {
F::from_noncanonical_u64(5646033492608483824), F::from_noncanonical_u64(5646033492608483824),
] ]
} }
fn test_vals_b<F: ReducibleAVX2>() -> [F; 4] { fn test_vals_b<F: ReducibleAvx2>() -> [F; 4] {
[ [
F::from_noncanonical_u64(17891926589593242302), F::from_noncanonical_u64(17891926589593242302),
F::from_noncanonical_u64(11009798273260028228), F::from_noncanonical_u64(11009798273260028228),
@ -32,17 +32,17 @@ mod tests {
] ]
} }
fn test_add<F: ReducibleAVX2>() fn test_add<F: ReducibleAvx2>()
where where
[(); PackedPrimeField::<F>::WIDTH]: , [(); Avx2PrimeField::<F>::WIDTH]:,
{ {
let a_arr = test_vals_a::<F>(); let a_arr = test_vals_a::<F>();
let b_arr = test_vals_b::<F>(); let b_arr = test_vals_b::<F>();
let packed_a = PackedPrimeField::<F>::from_arr(a_arr); let packed_a = Avx2PrimeField::<F>::from_arr(a_arr);
let packed_b = PackedPrimeField::<F>::from_arr(b_arr); let packed_b = Avx2PrimeField::<F>::from_arr(b_arr);
let packed_res = packed_a + packed_b; let packed_res = packed_a + packed_b;
let arr_res = packed_res.to_arr(); let arr_res = packed_res.as_arr();
let expected = a_arr.iter().zip(b_arr).map(|(&a, b)| a + b); let expected = a_arr.iter().zip(b_arr).map(|(&a, b)| a + b);
for (exp, res) in expected.zip(arr_res) { for (exp, res) in expected.zip(arr_res) {
@ -50,17 +50,17 @@ mod tests {
} }
} }
fn test_mul<F: ReducibleAVX2>() fn test_mul<F: ReducibleAvx2>()
where where
[(); PackedPrimeField::<F>::WIDTH]: , [(); Avx2PrimeField::<F>::WIDTH]:,
{ {
let a_arr = test_vals_a::<F>(); let a_arr = test_vals_a::<F>();
let b_arr = test_vals_b::<F>(); let b_arr = test_vals_b::<F>();
let packed_a = PackedPrimeField::<F>::from_arr(a_arr); let packed_a = Avx2PrimeField::<F>::from_arr(a_arr);
let packed_b = PackedPrimeField::<F>::from_arr(b_arr); let packed_b = Avx2PrimeField::<F>::from_arr(b_arr);
let packed_res = packed_a * packed_b; let packed_res = packed_a * packed_b;
let arr_res = packed_res.to_arr(); let arr_res = packed_res.as_arr();
let expected = a_arr.iter().zip(b_arr).map(|(&a, b)| a * b); let expected = a_arr.iter().zip(b_arr).map(|(&a, b)| a * b);
for (exp, res) in expected.zip(arr_res) { for (exp, res) in expected.zip(arr_res) {
@ -68,15 +68,15 @@ mod tests {
} }
} }
fn test_square<F: ReducibleAVX2>() fn test_square<F: ReducibleAvx2>()
where where
[(); PackedPrimeField::<F>::WIDTH]: , [(); Avx2PrimeField::<F>::WIDTH]:,
{ {
let a_arr = test_vals_a::<F>(); let a_arr = test_vals_a::<F>();
let packed_a = PackedPrimeField::<F>::from_arr(a_arr); let packed_a = Avx2PrimeField::<F>::from_arr(a_arr);
let packed_res = packed_a.square(); let packed_res = packed_a.square();
let arr_res = packed_res.to_arr(); let arr_res = packed_res.as_arr();
let expected = a_arr.iter().map(|&a| a.square()); let expected = a_arr.iter().map(|&a| a.square());
for (exp, res) in expected.zip(arr_res) { for (exp, res) in expected.zip(arr_res) {
@ -84,15 +84,15 @@ mod tests {
} }
} }
fn test_neg<F: ReducibleAVX2>() fn test_neg<F: ReducibleAvx2>()
where where
[(); PackedPrimeField::<F>::WIDTH]: , [(); Avx2PrimeField::<F>::WIDTH]:,
{ {
let a_arr = test_vals_a::<F>(); let a_arr = test_vals_a::<F>();
let packed_a = PackedPrimeField::<F>::from_arr(a_arr); let packed_a = Avx2PrimeField::<F>::from_arr(a_arr);
let packed_res = -packed_a; let packed_res = -packed_a;
let arr_res = packed_res.to_arr(); let arr_res = packed_res.as_arr();
let expected = a_arr.iter().map(|&a| -a); let expected = a_arr.iter().map(|&a| -a);
for (exp, res) in expected.zip(arr_res) { for (exp, res) in expected.zip(arr_res) {
@ -100,17 +100,17 @@ mod tests {
} }
} }
fn test_sub<F: ReducibleAVX2>() fn test_sub<F: ReducibleAvx2>()
where where
[(); PackedPrimeField::<F>::WIDTH]: , [(); Avx2PrimeField::<F>::WIDTH]:,
{ {
let a_arr = test_vals_a::<F>(); let a_arr = test_vals_a::<F>();
let b_arr = test_vals_b::<F>(); let b_arr = test_vals_b::<F>();
let packed_a = PackedPrimeField::<F>::from_arr(a_arr); let packed_a = Avx2PrimeField::<F>::from_arr(a_arr);
let packed_b = PackedPrimeField::<F>::from_arr(b_arr); let packed_b = Avx2PrimeField::<F>::from_arr(b_arr);
let packed_res = packed_a - packed_b; let packed_res = packed_a - packed_b;
let arr_res = packed_res.to_arr(); let arr_res = packed_res.as_arr();
let expected = a_arr.iter().zip(b_arr).map(|(&a, b)| a - b); let expected = a_arr.iter().zip(b_arr).map(|(&a, b)| a - b);
for (exp, res) in expected.zip(arr_res) { for (exp, res) in expected.zip(arr_res) {
@ -118,33 +118,39 @@ mod tests {
} }
} }
fn test_interleave_is_involution<F: ReducibleAVX2>() fn test_interleave_is_involution<F: ReducibleAvx2>()
where where
[(); PackedPrimeField::<F>::WIDTH]: , [(); Avx2PrimeField::<F>::WIDTH]:,
{ {
let a_arr = test_vals_a::<F>(); let a_arr = test_vals_a::<F>();
let b_arr = test_vals_b::<F>(); let b_arr = test_vals_b::<F>();
let packed_a = PackedPrimeField::<F>::from_arr(a_arr); let packed_a = Avx2PrimeField::<F>::from_arr(a_arr);
let packed_b = PackedPrimeField::<F>::from_arr(b_arr); let packed_b = Avx2PrimeField::<F>::from_arr(b_arr);
{ {
// Interleave, then deinterleave. // Interleave, then deinterleave.
let (x, y) = packed_a.interleave(packed_b, 0);
let (res_a, res_b) = x.interleave(y, 0);
assert_eq!(res_a.to_arr(), a_arr);
assert_eq!(res_b.to_arr(), b_arr);
}
{
let (x, y) = packed_a.interleave(packed_b, 1); let (x, y) = packed_a.interleave(packed_b, 1);
let (res_a, res_b) = x.interleave(y, 1); let (res_a, res_b) = x.interleave(y, 1);
assert_eq!(res_a.to_arr(), a_arr); assert_eq!(res_a.as_arr(), a_arr);
assert_eq!(res_b.to_arr(), b_arr); assert_eq!(res_b.as_arr(), b_arr);
}
{
let (x, y) = packed_a.interleave(packed_b, 2);
let (res_a, res_b) = x.interleave(y, 2);
assert_eq!(res_a.as_arr(), a_arr);
assert_eq!(res_b.as_arr(), b_arr);
}
{
let (x, y) = packed_a.interleave(packed_b, 4);
let (res_a, res_b) = x.interleave(y, 4);
assert_eq!(res_a.as_arr(), a_arr);
assert_eq!(res_b.as_arr(), b_arr);
} }
} }
fn test_interleave<F: ReducibleAVX2>() fn test_interleave<F: ReducibleAvx2>()
where where
[(); PackedPrimeField::<F>::WIDTH]: , [(); Avx2PrimeField::<F>::WIDTH]:,
{ {
let in_a: [F; 4] = [ let in_a: [F; 4] = [
F::from_noncanonical_u64(00), F::from_noncanonical_u64(00),
@ -158,42 +164,47 @@ mod tests {
F::from_noncanonical_u64(12), F::from_noncanonical_u64(12),
F::from_noncanonical_u64(13), F::from_noncanonical_u64(13),
]; ];
let int0_a: [F; 4] = [ let int1_a: [F; 4] = [
F::from_noncanonical_u64(00), F::from_noncanonical_u64(00),
F::from_noncanonical_u64(10), F::from_noncanonical_u64(10),
F::from_noncanonical_u64(02), F::from_noncanonical_u64(02),
F::from_noncanonical_u64(12), F::from_noncanonical_u64(12),
]; ];
let int0_b: [F; 4] = [ let int1_b: [F; 4] = [
F::from_noncanonical_u64(01), F::from_noncanonical_u64(01),
F::from_noncanonical_u64(11), F::from_noncanonical_u64(11),
F::from_noncanonical_u64(03), F::from_noncanonical_u64(03),
F::from_noncanonical_u64(13), F::from_noncanonical_u64(13),
]; ];
let int1_a: [F; 4] = [ let int2_a: [F; 4] = [
F::from_noncanonical_u64(00), F::from_noncanonical_u64(00),
F::from_noncanonical_u64(01), F::from_noncanonical_u64(01),
F::from_noncanonical_u64(10), F::from_noncanonical_u64(10),
F::from_noncanonical_u64(11), F::from_noncanonical_u64(11),
]; ];
let int1_b: [F; 4] = [ let int2_b: [F; 4] = [
F::from_noncanonical_u64(02), F::from_noncanonical_u64(02),
F::from_noncanonical_u64(03), F::from_noncanonical_u64(03),
F::from_noncanonical_u64(12), F::from_noncanonical_u64(12),
F::from_noncanonical_u64(13), F::from_noncanonical_u64(13),
]; ];
let packed_a = PackedPrimeField::<F>::from_arr(in_a); let packed_a = Avx2PrimeField::<F>::from_arr(in_a);
let packed_b = PackedPrimeField::<F>::from_arr(in_b); let packed_b = Avx2PrimeField::<F>::from_arr(in_b);
{
let (x0, y0) = packed_a.interleave(packed_b, 0);
assert_eq!(x0.to_arr(), int0_a);
assert_eq!(y0.to_arr(), int0_b);
}
{ {
let (x1, y1) = packed_a.interleave(packed_b, 1); let (x1, y1) = packed_a.interleave(packed_b, 1);
assert_eq!(x1.to_arr(), int1_a); assert_eq!(x1.as_arr(), int1_a);
assert_eq!(y1.to_arr(), int1_b); assert_eq!(y1.as_arr(), int1_b);
}
{
let (x2, y2) = packed_a.interleave(packed_b, 2);
assert_eq!(x2.as_arr(), int2_a);
assert_eq!(y2.as_arr(), int2_b);
}
{
let (x4, y4) = packed_a.interleave(packed_b, 4);
assert_eq!(x4.as_arr(), in_a);
assert_eq!(y4.as_arr(), in_b);
} }
} }


@ -1,77 +1,81 @@
use std::fmt; use std::fmt::Debug;
use std::fmt::{Debug, Formatter};
use std::iter::{Product, Sum}; use std::iter::{Product, Sum};
use std::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign}; use std::ops::{Add, AddAssign, Div, Mul, MulAssign, Neg, Sub, SubAssign};
use std::slice;
use crate::field::field_types::Field; use crate::field::field_types::Field;
pub trait PackedField: /// # Safety
/// - WIDTH is assumed to be a power of 2.
/// - If P implements PackedField then P must be castable to/from [P::Scalar; P::WIDTH] without UB.
pub unsafe trait PackedField:
'static 'static
+ Add<Self, Output = Self> + Add<Self, Output = Self>
+ Add<Self::FieldType, Output = Self> + Add<Self::Scalar, Output = Self>
+ AddAssign<Self> + AddAssign<Self>
+ AddAssign<Self::FieldType> + AddAssign<Self::Scalar>
+ Copy + Copy
+ Debug + Debug
+ Default + Default
// TODO: Implementing Div sounds like a pain so it's a worry for later. + From<Self::Scalar>
// TODO: Implement packed / packed division
+ Div<Self::Scalar, Output = Self>
+ Mul<Self, Output = Self> + Mul<Self, Output = Self>
+ Mul<Self::FieldType, Output = Self> + Mul<Self::Scalar, Output = Self>
+ MulAssign<Self> + MulAssign<Self>
+ MulAssign<Self::FieldType> + MulAssign<Self::Scalar>
+ Neg<Output = Self> + Neg<Output = Self>
+ Product + Product
+ Send + Send
+ Sub<Self, Output = Self> + Sub<Self, Output = Self>
+ Sub<Self::FieldType, Output = Self> + Sub<Self::Scalar, Output = Self>
+ SubAssign<Self> + SubAssign<Self>
+ SubAssign<Self::FieldType> + SubAssign<Self::Scalar>
+ Sum + Sum
+ Sync + Sync
where
Self::Scalar: Add<Self, Output = Self>,
Self::Scalar: Mul<Self, Output = Self>,
Self::Scalar: Sub<Self, Output = Self>,
{ {
type FieldType: Field; type Scalar: Field;
const LOG2_WIDTH: usize; const WIDTH: usize;
const WIDTH: usize = 1 << Self::LOG2_WIDTH; const ZERO: Self;
const ONE: Self;
fn square(&self) -> Self { fn square(&self) -> Self {
*self * *self *self * *self
} }
fn zero() -> Self { fn from_arr(arr: [Self::Scalar; Self::WIDTH]) -> Self;
Self::broadcast(Self::FieldType::ZERO) fn as_arr(&self) -> [Self::Scalar; Self::WIDTH];
}
fn one() -> Self {
Self::broadcast(Self::FieldType::ONE)
}
fn broadcast(x: Self::FieldType) -> Self; fn from_slice(slice: &[Self::Scalar]) -> &Self;
fn from_slice_mut(slice: &mut [Self::Scalar]) -> &mut Self;
fn as_slice(&self) -> &[Self::Scalar];
fn as_slice_mut(&mut self) -> &mut [Self::Scalar];
fn from_arr(arr: [Self::FieldType; Self::WIDTH]) -> Self; /// Interpret two vectors as chunks of block_len elements. Unpack and interleave those
fn to_arr(&self) -> [Self::FieldType; Self::WIDTH];
fn from_slice(slice: &[Self::FieldType]) -> Self;
fn to_vec(&self) -> Vec<Self::FieldType>;
/// Interpret two vectors as chunks of (1 << r) elements. Unpack and interleave those
/// chunks. This is best seen with an example. If we have: /// chunks. This is best seen with an example. If we have:
/// A = [x0, y0, x1, y1], /// A = [x0, y0, x1, y1],
/// B = [x2, y2, x3, y3], /// B = [x2, y2, x3, y3],
/// then /// then
/// interleave(A, B, 0) = ([x0, x2, x1, x3], [y0, y2, y1, y3]). /// interleave(A, B, 1) = ([x0, x2, x1, x3], [y0, y2, y1, y3]).
/// Pairs that were adjacent in the input are at corresponding positions in the output. /// Pairs that were adjacent in the input are at corresponding positions in the output.
/// r lets us set the size of chunks we're interleaving. If we set r = 1, then for /// block_len lets us set the size of chunks we're interleaving. If we set block_len = 2, then for
/// A = [x0, x1, y0, y1], /// A = [x0, x1, y0, y1],
/// B = [x2, x3, y2, y3], /// B = [x2, x3, y2, y3],
/// we obtain /// we obtain
/// interleave(A, B, r) = ([x0, x1, x2, x3], [y0, y1, y2, y3]). /// interleave(A, B, block_len) = ([x0, x1, x2, x3], [y0, y1, y2, y3]).
/// We can also think about this as stacking the vectors, dividing them into 2x2 matrices, and /// We can also think about this as stacking the vectors, dividing them into 2x2 matrices, and
/// transposing those matrices. /// transposing those matrices.
/// When r = LOG2_WIDTH, this operation is a no-op. Values of r > LOG2_WIDTH are not /// When block_len = WIDTH, this operation is a no-op. block_len must divide WIDTH. Since
/// permitted. /// WIDTH is specified to be a power of 2, block_len must also be a power of 2. It cannot be 0
fn interleave(&self, other: Self, r: usize) -> (Self, Self); /// and it cannot be > WIDTH.
fn interleave(&self, other: Self, block_len: usize) -> (Self, Self);
fn pack_slice(buf: &[Self::FieldType]) -> &[Self] { fn pack_slice(buf: &[Self::Scalar]) -> &[Self] {
assert!( assert!(
buf.len() % Self::WIDTH == 0, buf.len() % Self::WIDTH == 0,
"Slice length (got {}) must be a multiple of packed field width ({}).", "Slice length (got {}) must be a multiple of packed field width ({}).",
@ -82,7 +86,7 @@ pub trait PackedField:
let n = buf.len() / Self::WIDTH; let n = buf.len() / Self::WIDTH;
unsafe { std::slice::from_raw_parts(buf_ptr, n) } unsafe { std::slice::from_raw_parts(buf_ptr, n) }
} }
fn pack_slice_mut(buf: &mut [Self::FieldType]) -> &mut [Self] { fn pack_slice_mut(buf: &mut [Self::Scalar]) -> &mut [Self] {
assert!( assert!(
buf.len() % Self::WIDTH == 0, buf.len() % Self::WIDTH == 0,
"Slice length (got {}) must be a multiple of packed field width ({}).", "Slice length (got {}) must be a multiple of packed field width ({}).",
@ -95,143 +99,41 @@ pub trait PackedField:
} }
} }
#[derive(Copy, Clone)] unsafe impl<F: Field> PackedField for F {
#[repr(transparent)] type Scalar = Self;
pub struct Singleton<F: Field>(pub F);
impl<F: Field> Add<Self> for Singleton<F> { const WIDTH: usize = 1;
type Output = Self; const ZERO: Self = <F as Field>::ZERO;
fn add(self, rhs: Self) -> Self { const ONE: Self = <F as Field>::ONE;
Self(self.0 + rhs.0)
}
}
impl<F: Field> Add<F> for Singleton<F> {
type Output = Self;
fn add(self, rhs: F) -> Self {
self + Self::broadcast(rhs)
}
}
impl<F: Field> AddAssign<Self> for Singleton<F> {
fn add_assign(&mut self, rhs: Self) {
*self = *self + rhs;
}
}
impl<F: Field> AddAssign<F> for Singleton<F> {
fn add_assign(&mut self, rhs: F) {
*self = *self + rhs;
}
}
impl<F: Field> Debug for Singleton<F> {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
write!(f, "({:?})", self.0)
}
}
impl<F: Field> Default for Singleton<F> {
fn default() -> Self {
Self::zero()
}
}
impl<F: Field> Mul<Self> for Singleton<F> {
type Output = Self;
fn mul(self, rhs: Self) -> Self {
Self(self.0 * rhs.0)
}
}
impl<F: Field> Mul<F> for Singleton<F> {
type Output = Self;
fn mul(self, rhs: F) -> Self {
self * Self::broadcast(rhs)
}
}
impl<F: Field> MulAssign<Self> for Singleton<F> {
fn mul_assign(&mut self, rhs: Self) {
*self = *self * rhs;
}
}
impl<F: Field> MulAssign<F> for Singleton<F> {
fn mul_assign(&mut self, rhs: F) {
*self = *self * rhs;
}
}
impl<F: Field> Neg for Singleton<F> {
type Output = Self;
fn neg(self) -> Self {
Self(-self.0)
}
}
impl<F: Field> Product for Singleton<F> {
fn product<I: Iterator<Item = Self>>(iter: I) -> Self {
Self(iter.map(|x| x.0).product())
}
}
impl<F: Field> PackedField for Singleton<F> {
const LOG2_WIDTH: usize = 0;
type FieldType = F;
fn broadcast(x: F) -> Self {
Self(x)
}
fn from_arr(arr: [Self::FieldType; Self::WIDTH]) -> Self {
Self(arr[0])
}
fn to_arr(&self) -> [Self::FieldType; Self::WIDTH] {
[self.0]
}
fn from_slice(slice: &[Self::FieldType]) -> Self {
assert!(slice.len() == 1);
Self(slice[0])
}
fn to_vec(&self) -> Vec<Self::FieldType> {
vec![self.0]
}
fn interleave(&self, other: Self, r: usize) -> (Self, Self) {
match r {
0 => (*self, other), // This is a no-op whenever r == LOG2_WIDTH.
_ => panic!("r cannot be more than LOG2_WIDTH"),
}
}
fn square(&self) -> Self { fn square(&self) -> Self {
Self(self.0.square()) <Self as Field>::square(self)
}
} }
impl<F: Field> Sub<Self> for Singleton<F> { fn from_arr(arr: [Self::Scalar; Self::WIDTH]) -> Self {
type Output = Self; arr[0]
fn sub(self, rhs: Self) -> Self {
Self(self.0 - rhs.0)
}
}
impl<F: Field> Sub<F> for Singleton<F> {
type Output = Self;
fn sub(self, rhs: F) -> Self {
self - Self::broadcast(rhs)
}
}
impl<F: Field> SubAssign<Self> for Singleton<F> {
fn sub_assign(&mut self, rhs: Self) {
*self = *self - rhs;
}
}
impl<F: Field> SubAssign<F> for Singleton<F> {
fn sub_assign(&mut self, rhs: F) {
*self = *self - rhs;
} }
fn as_arr(&self) -> [Self::Scalar; Self::WIDTH] {
[*self]
} }
impl<F: Field> Sum for Singleton<F> { fn from_slice(slice: &[Self::Scalar]) -> &Self {
fn sum<I: Iterator<Item = Self>>(iter: I) -> Self { &slice[0]
Self(iter.map(|x| x.0).sum()) }
fn from_slice_mut(slice: &mut [Self::Scalar]) -> &mut Self {
&mut slice[0]
}
fn as_slice(&self) -> &[Self::Scalar] {
slice::from_ref(self)
}
fn as_slice_mut(&mut self) -> &mut [Self::Scalar] {
slice::from_mut(self)
}
fn interleave(&self, other: Self, block_len: usize) -> (Self, Self) {
match block_len {
1 => (*self, other),
_ => panic!("unsupported block length"),
}
} }
} }
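With the blanket impl above (every `Field` is a `PackedField` of width 1), a kernel can be written once against `PackedField` and instantiated with either the plain scalar field or `PackedGoldilocksAvx2`. A minimal sketch (the function name and the length assumption are ours):

```rust
use crate::field::field_types::Field;
use crate::field::packed_field::PackedField;

fn packed_sum<P: PackedField>(values: &[P::Scalar]) -> P::Scalar {
    // Assumes values.len() is a multiple of P::WIDTH.
    let mut acc = P::ZERO;
    for &chunk in P::pack_slice(values) {
        acc += chunk;
    }
    // Horizontal reduction across the lanes.
    acc.as_slice()
        .iter()
        .copied()
        .fold(P::Scalar::ZERO, |a, b| a + b)
}
```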


@ -24,7 +24,7 @@ where
ExpectedOp: Fn(u64) -> u64, ExpectedOp: Fn(u64) -> u64,
{ {
let inputs = test_inputs(F::ORDER); let inputs = test_inputs(F::ORDER);
let expected: Vec<_> = inputs.iter().map(|x| expected_op(x.clone())).collect(); let expected: Vec<_> = inputs.iter().map(|&x| expected_op(x)).collect();
let output: Vec<_> = inputs let output: Vec<_> = inputs
.iter() .iter()
.cloned() .cloned()
@ -144,7 +144,7 @@ macro_rules! test_prime_field_arithmetic {
fn inverse_2exp() { fn inverse_2exp() {
type F = $field; type F = $field;
let v = <F as Field>::PrimeField::TWO_ADICITY; let v = <F as Field>::TWO_ADICITY;
for e in [0, 1, 2, 3, 4, v - 2, v - 1, v, v + 1, v + 2, 123 * v] { for e in [0, 1, 2, 3, 4, v - 2, v - 1, v, v + 1, v + 2, 123 * v] {
let x = F::TWO.exp_u64(e as u64); let x = F::TWO.exp_u64(e as u64);


@ -1,4 +1,3 @@
use std::convert::TryInto;
use std::fmt; use std::fmt;
use std::fmt::{Debug, Display, Formatter}; use std::fmt::{Debug, Display, Formatter};
use std::hash::{Hash, Hasher}; use std::hash::{Hash, Hasher};
@ -12,7 +11,6 @@ use rand::Rng;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use crate::field::field_types::Field; use crate::field::field_types::Field;
use crate::field::goldilocks_field::GoldilocksField;
/// The base field of the secp256k1 elliptic curve. /// The base field of the secp256k1 elliptic curve.
/// ///
@ -36,8 +34,80 @@ fn biguint_from_array(arr: [u64; 4]) -> BigUint {
]) ])
} }
impl Secp256K1Base { impl Default for Secp256K1Base {
fn to_canonical_biguint(&self) -> BigUint { fn default() -> Self {
Self::ZERO
}
}
impl PartialEq for Secp256K1Base {
fn eq(&self, other: &Self) -> bool {
self.to_biguint() == other.to_biguint()
}
}
impl Eq for Secp256K1Base {}
impl Hash for Secp256K1Base {
fn hash<H: Hasher>(&self, state: &mut H) {
self.to_biguint().hash(state)
}
}
impl Display for Secp256K1Base {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
Display::fmt(&self.to_biguint(), f)
}
}
impl Debug for Secp256K1Base {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
Debug::fmt(&self.to_biguint(), f)
}
}
impl Field for Secp256K1Base {
const ZERO: Self = Self([0; 4]);
const ONE: Self = Self([1, 0, 0, 0]);
const TWO: Self = Self([2, 0, 0, 0]);
const NEG_ONE: Self = Self([
0xFFFFFFFEFFFFFC2E,
0xFFFFFFFFFFFFFFFF,
0xFFFFFFFFFFFFFFFF,
0xFFFFFFFFFFFFFFFF,
]);
const TWO_ADICITY: usize = 1;
const CHARACTERISTIC_TWO_ADICITY: usize = Self::TWO_ADICITY;
// Sage: `g = GF(p).multiplicative_generator()`
const MULTIPLICATIVE_GROUP_GENERATOR: Self = Self([5, 0, 0, 0]);
// Sage: `g_2 = g^((p - 1) / 2)`
const POWER_OF_TWO_GENERATOR: Self = Self::NEG_ONE;
const BITS: usize = 256;
fn order() -> BigUint {
BigUint::from_slice(&[
0xFFFFFC2F, 0xFFFFFFFE, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF,
0xFFFFFFFF,
])
}
fn characteristic() -> BigUint {
Self::order()
}
fn try_inverse(&self) -> Option<Self> {
if self.is_zero() {
return None;
}
// Fermat's Little Theorem
Some(self.exp_biguint(&(Self::order() - BigUint::one() - BigUint::one())))
}
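As a standalone check of the Fermat inversion used here (a sketch with `num`'s `BigUint`; the secp256k1 base-field modulus is written out by hand): for prime p and a != 0, a^(p-2) * a ≡ 1 (mod p).

```rust
use num::bigint::BigUint;
use num::One;

fn main() {
    let p = BigUint::parse_bytes(
        b"FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEFFFFFC2F",
        16,
    )
    .unwrap();
    let a = BigUint::from(12345u32);
    // a^(p-2) is the inverse of a by Fermat's little theorem.
    let a_inv = a.modpow(&(p.clone() - 2u32), &p);
    assert!((a * a_inv % &p).is_one());
}
```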
fn to_biguint(&self) -> BigUint {
let mut result = biguint_from_array(self.0); let mut result = biguint_from_array(self.0);
if result >= Self::order() { if result >= Self::order() {
result -= Self::order(); result -= Self::order();
@ -55,79 +125,6 @@ impl Secp256K1Base {
.expect("error converting to u64 array"), .expect("error converting to u64 array"),
) )
} }
}
impl Default for Secp256K1Base {
fn default() -> Self {
Self::ZERO
}
}
impl PartialEq for Secp256K1Base {
fn eq(&self, other: &Self) -> bool {
self.to_canonical_biguint() == other.to_canonical_biguint()
}
}
impl Eq for Secp256K1Base {}
impl Hash for Secp256K1Base {
fn hash<H: Hasher>(&self, state: &mut H) {
self.to_canonical_biguint().hash(state)
}
}
impl Display for Secp256K1Base {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
Display::fmt(&self.to_canonical_biguint(), f)
}
}
impl Debug for Secp256K1Base {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
Debug::fmt(&self.to_canonical_biguint(), f)
}
}
impl Field for Secp256K1Base {
// TODO: fix
type PrimeField = GoldilocksField;
const ZERO: Self = Self([0; 4]);
const ONE: Self = Self([1, 0, 0, 0]);
const TWO: Self = Self([2, 0, 0, 0]);
const NEG_ONE: Self = Self([
0xFFFFFFFEFFFFFC2E,
0xFFFFFFFFFFFFFFFF,
0xFFFFFFFFFFFFFFFF,
0xFFFFFFFFFFFFFFFF,
]);
// TODO: fix
const CHARACTERISTIC: u64 = 0;
const TWO_ADICITY: usize = 1;
// Sage: `g = GF(p).multiplicative_generator()`
const MULTIPLICATIVE_GROUP_GENERATOR: Self = Self([5, 0, 0, 0]);
// Sage: `g_2 = g^((p - 1) / 2)`
const POWER_OF_TWO_GENERATOR: Self = Self::NEG_ONE;
fn order() -> BigUint {
BigUint::from_slice(&[
0xFFFFFC2F, 0xFFFFFFFE, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF,
0xFFFFFFFF,
])
}
fn try_inverse(&self) -> Option<Self> {
if self.is_zero() {
return None;
}
// Fermat's Little Theorem
Some(self.exp_biguint(&(Self::order() - BigUint::one() - BigUint::one())))
}
#[inline] #[inline]
fn from_canonical_u64(n: u64) -> Self { fn from_canonical_u64(n: u64) -> Self {
@ -157,7 +154,7 @@ impl Neg for Secp256K1Base {
if self.is_zero() { if self.is_zero() {
Self::ZERO Self::ZERO
} else { } else {
Self::from_biguint(Self::order() - self.to_canonical_biguint()) Self::from_biguint(Self::order() - self.to_biguint())
} }
} }
} }
@ -167,7 +164,7 @@ impl Add for Secp256K1Base {
#[inline] #[inline]
fn add(self, rhs: Self) -> Self { fn add(self, rhs: Self) -> Self {
let mut result = self.to_canonical_biguint() + rhs.to_canonical_biguint(); let mut result = self.to_biguint() + rhs.to_biguint();
if result >= Self::order() { if result >= Self::order() {
result -= Self::order(); result -= Self::order();
} }
@ -210,9 +207,7 @@ impl Mul for Secp256K1Base {
#[inline] #[inline]
fn mul(self, rhs: Self) -> Self { fn mul(self, rhs: Self) -> Self {
Self::from_biguint( Self::from_biguint((self.to_biguint() * rhs.to_biguint()).mod_floor(&Self::order()))
(self.to_canonical_biguint() * rhs.to_canonical_biguint()).mod_floor(&Self::order()),
)
} }
} }
@ -244,3 +239,10 @@ impl DivAssign for Secp256K1Base {
*self = *self / rhs; *self = *self / rhs;
} }
} }
#[cfg(test)]
mod tests {
use crate::test_field_arithmetic;
test_field_arithmetic!(crate::field::secp256k1_base::Secp256K1Base);
}


@ -0,0 +1,257 @@
use std::convert::TryInto;
use std::fmt;
use std::fmt::{Debug, Display, Formatter};
use std::hash::{Hash, Hasher};
use std::iter::{Product, Sum};
use std::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssign};
use itertools::Itertools;
use num::bigint::{BigUint, RandBigInt};
use num::{Integer, One};
use rand::Rng;
use serde::{Deserialize, Serialize};
use crate::field::field_types::Field;
/// The scalar field of the secp256k1 elliptic curve.
///
/// Its order is
/// ```ignore
/// P = 0xFFFFFFFF FFFFFFFF FFFFFFFF FFFFFFFE BAAEDCE6 AF48A03B BFD25E8C D0364141
/// = 115792089237316195423570985008687907852837564279074904382605163141518161494337
/// = 2**256 - 432420386565659656852420866394968145599
/// ```
#[derive(Copy, Clone, Serialize, Deserialize)]
pub struct Secp256K1Scalar(pub [u64; 4]);
fn biguint_from_array(arr: [u64; 4]) -> BigUint {
BigUint::from_slice(&[
arr[0] as u32,
(arr[0] >> 32) as u32,
arr[1] as u32,
(arr[1] >> 32) as u32,
arr[2] as u32,
(arr[2] >> 32) as u32,
arr[3] as u32,
(arr[3] >> 32) as u32,
])
}
impl Default for Secp256K1Scalar {
fn default() -> Self {
Self::ZERO
}
}
impl PartialEq for Secp256K1Scalar {
fn eq(&self, other: &Self) -> bool {
self.to_biguint() == other.to_biguint()
}
}
impl Eq for Secp256K1Scalar {}
impl Hash for Secp256K1Scalar {
fn hash<H: Hasher>(&self, state: &mut H) {
self.to_biguint().hash(state)
}
}
impl Display for Secp256K1Scalar {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
Display::fmt(&self.to_biguint(), f)
}
}
impl Debug for Secp256K1Scalar {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
Debug::fmt(&self.to_biguint(), f)
}
}
impl Field for Secp256K1Scalar {
const ZERO: Self = Self([0; 4]);
const ONE: Self = Self([1, 0, 0, 0]);
const TWO: Self = Self([2, 0, 0, 0]);
const NEG_ONE: Self = Self([
0xBFD25E8CD0364140,
0xBAAEDCE6AF48A03B,
0xFFFFFFFFFFFFFFFE,
0xFFFFFFFFFFFFFFFF,
]);
const TWO_ADICITY: usize = 6;
const CHARACTERISTIC_TWO_ADICITY: usize = Self::TWO_ADICITY;
// Sage: `g = GF(p).multiplicative_generator()`
const MULTIPLICATIVE_GROUP_GENERATOR: Self = Self([7, 0, 0, 0]);
// Sage: `g_2 = power_mod(g, (p - 1) // 2^6, p)`
// 5480320495727936603795231718619559942670027629901634955707709633242980176626
const POWER_OF_TWO_GENERATOR: Self = Self([
0x992f4b5402b052f2,
0x98BDEAB680756045,
0xDF9879A3FBC483A8,
0xC1DC060E7A91986,
]);
const BITS: usize = 256;
fn order() -> BigUint {
BigUint::from_slice(&[
0xD0364141, 0xBFD25E8C, 0xAF48A03B, 0xBAAEDCE6, 0xFFFFFFFE, 0xFFFFFFFF, 0xFFFFFFFF,
0xFFFFFFFF,
])
}
fn characteristic() -> BigUint {
Self::order()
}
fn try_inverse(&self) -> Option<Self> {
if self.is_zero() {
return None;
}
// Fermat's little theorem: a^(p-2) = a^-1 mod p for nonzero a.
Some(self.exp_biguint(&(Self::order() - BigUint::one() - BigUint::one())))
}
fn to_biguint(&self) -> BigUint {
let mut result = biguint_from_array(self.0);
if result >= Self::order() {
result -= Self::order();
}
result
}
fn from_biguint(val: BigUint) -> Self {
Self(
val.to_u64_digits()
.into_iter()
.pad_using(4, |_| 0)
.collect::<Vec<_>>()[..]
.try_into()
.expect("error converting to u64 array"),
)
}
#[inline]
fn from_canonical_u64(n: u64) -> Self {
Self([n, 0, 0, 0])
}
#[inline]
fn from_noncanonical_u128(n: u128) -> Self {
Self([n as u64, (n >> 64) as u64, 0, 0])
}
#[inline]
fn from_noncanonical_u96(n: (u64, u32)) -> Self {
Self([n.0, n.1 as u64, 0, 0])
}
fn rand_from_rng<R: Rng>(rng: &mut R) -> Self {
Self::from_biguint(rng.gen_biguint_below(&Self::order()))
}
}
impl Neg for Secp256K1Scalar {
type Output = Self;
#[inline]
fn neg(self) -> Self {
if self.is_zero() {
Self::ZERO
} else {
Self::from_biguint(Self::order() - self.to_biguint())
}
}
}
impl Add for Secp256K1Scalar {
type Output = Self;
#[inline]
fn add(self, rhs: Self) -> Self {
let mut result = self.to_biguint() + rhs.to_biguint();
if result >= Self::order() {
result -= Self::order();
}
Self::from_biguint(result)
}
}
impl AddAssign for Secp256K1Scalar {
#[inline]
fn add_assign(&mut self, rhs: Self) {
*self = *self + rhs;
}
}
impl Sum for Secp256K1Scalar {
fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
iter.fold(Self::ZERO, |acc, x| acc + x)
}
}
impl Sub for Secp256K1Scalar {
type Output = Self;
#[inline]
#[allow(clippy::suspicious_arithmetic_impl)]
fn sub(self, rhs: Self) -> Self {
self + -rhs
}
}
impl SubAssign for Secp256K1Scalar {
#[inline]
fn sub_assign(&mut self, rhs: Self) {
*self = *self - rhs;
}
}
impl Mul for Secp256K1Scalar {
type Output = Self;
#[inline]
fn mul(self, rhs: Self) -> Self {
Self::from_biguint((self.to_biguint() * rhs.to_biguint()).mod_floor(&Self::order()))
}
}
impl MulAssign for Secp256K1Scalar {
#[inline]
fn mul_assign(&mut self, rhs: Self) {
*self = *self * rhs;
}
}
impl Product for Secp256K1Scalar {
#[inline]
fn product<I: Iterator<Item = Self>>(iter: I) -> Self {
iter.reduce(|acc, x| acc * x).unwrap_or(Self::ONE)
}
}
impl Div for Secp256K1Scalar {
type Output = Self;
#[allow(clippy::suspicious_arithmetic_impl)]
fn div(self, rhs: Self) -> Self::Output {
self * rhs.inverse()
}
}
impl DivAssign for Secp256K1Scalar {
fn div_assign(&mut self, rhs: Self) {
*self = *self / rhs;
}
}
#[cfg(test)]
mod tests {
use crate::test_field_arithmetic;
test_field_arithmetic!(crate::field::secp256k1_scalar::Secp256K1Scalar);
}


@ -11,14 +11,14 @@ use crate::plonk::circuit_data::CommonCircuitData;
use crate::plonk::config::GenericConfig; use crate::plonk::config::GenericConfig;
use crate::plonk::plonk_common::PlonkPolynomials; use crate::plonk::plonk_common::PlonkPolynomials;
use crate::plonk::proof::OpeningSet; use crate::plonk::proof::OpeningSet;
use crate::polynomial::polynomial::{PolynomialCoeffs, PolynomialValues}; use crate::polynomial::{PolynomialCoeffs, PolynomialValues};
use crate::timed; use crate::timed;
use crate::util::reducing::ReducingFactor; use crate::util::reducing::ReducingFactor;
use crate::util::timing::TimingTree; use crate::util::timing::TimingTree;
use crate::util::{log2_strict, reverse_bits, reverse_index_bits_in_place, transpose}; use crate::util::{log2_strict, reverse_bits, reverse_index_bits_in_place, transpose};
/// Two (~64 bit) field elements give ~128 bit security. /// Four (~64 bit) field elements give ~128 bit security.
pub const SALT_SIZE: usize = 2; pub const SALT_SIZE: usize = 4;
/// Represents a batch FRI based commitment to a list of polynomials. /// Represents a batch FRI based commitment to a list of polynomials.
pub struct PolynomialBatchCommitment<F: Extendable<D>, C: GenericConfig<D, F = F>, const D: usize> { pub struct PolynomialBatchCommitment<F: Extendable<D>, C: GenericConfig<D, F = F>, const D: usize> {


@ -36,8 +36,4 @@ impl FriParams {
pub(crate) fn max_arity_bits(&self) -> Option<usize> { pub(crate) fn max_arity_bits(&self) -> Option<usize> {
self.reduction_arity_bits.iter().copied().max() self.reduction_arity_bits.iter().copied().max()
} }
pub(crate) fn max_arity(&self) -> Option<usize> {
self.max_arity_bits().map(|bits| 1 << bits)
}
} }


@ -16,7 +16,7 @@ use crate::plonk::circuit_data::CommonCircuitData;
use crate::plonk::config::{GenericConfig, Hasher}; use crate::plonk::config::{GenericConfig, Hasher};
use crate::plonk::plonk_common::PolynomialsIndexBlinding; use crate::plonk::plonk_common::PolynomialsIndexBlinding;
use crate::plonk::proof::{FriInferredElements, ProofChallenges}; use crate::plonk::proof::{FriInferredElements, ProofChallenges};
use crate::polynomial::polynomial::PolynomialCoeffs; use crate::polynomial::PolynomialCoeffs;
/// Evaluations and Merkle proof produced by the prover in a FRI query step. /// Evaluations and Merkle proof produced by the prover in a FRI query step.
#[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq)] #[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq)]


@ -9,7 +9,7 @@ use crate::iop::challenger::Challenger;
use crate::plonk::circuit_data::CommonCircuitData; use crate::plonk::circuit_data::CommonCircuitData;
use crate::plonk::config::{GenericConfig, Hasher}; use crate::plonk::config::{GenericConfig, Hasher};
use crate::plonk::plonk_common::reduce_with_powers; use crate::plonk::plonk_common::reduce_with_powers;
use crate::polynomial::polynomial::{PolynomialCoeffs, PolynomialValues}; use crate::polynomial::{PolynomialCoeffs, PolynomialValues};
use crate::timed; use crate::timed;
use crate::util::reverse_index_bits_in_place; use crate::util::reverse_index_bits_in_place;
use crate::util::timing::TimingTree; use crate::util::timing::TimingTree;


@ -3,13 +3,16 @@ use crate::field::extension_field::Extendable;
use crate::field::field_types::{Field, RichField}; use crate::field::field_types::{Field, RichField};
use crate::fri::proof::{FriInitialTreeProofTarget, FriProofTarget, FriQueryRoundTarget}; use crate::fri::proof::{FriInitialTreeProofTarget, FriProofTarget, FriQueryRoundTarget};
use crate::fri::FriConfig; use crate::fri::FriConfig;
use crate::gadgets::interpolation::InterpolationGate;
use crate::gates::gate::Gate; use crate::gates::gate::Gate;
use crate::gates::interpolation::InterpolationGate; use crate::gates::interpolation::HighDegreeInterpolationGate;
use crate::gates::low_degree_interpolation::LowDegreeInterpolationGate;
use crate::gates::random_access::RandomAccessGate; use crate::gates::random_access::RandomAccessGate;
use crate::hash::hash_types::MerkleCapTarget; use crate::hash::hash_types::MerkleCapTarget;
use crate::iop::challenger::RecursiveChallenger; use crate::iop::challenger::RecursiveChallenger;
use crate::iop::target::{BoolTarget, Target}; use crate::iop::target::{BoolTarget, Target};
use crate::plonk::circuit_builder::CircuitBuilder; use crate::plonk::circuit_builder::CircuitBuilder;
use crate::plonk::circuit_data::{CircuitConfig, CommonCircuitData};
use crate::plonk::circuit_data::CommonCircuitData; use crate::plonk::circuit_data::CommonCircuitData;
use crate::plonk::config::{AlgebraicConfig, AlgebraicHasher, GenericConfig}; use crate::plonk::config::{AlgebraicConfig, AlgebraicHasher, GenericConfig};
use crate::plonk::plonk_common::PlonkPolynomials; use crate::plonk::plonk_common::PlonkPolynomials;
@ -28,6 +31,7 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
arity_bits: usize, arity_bits: usize,
evals: &[ExtensionTarget<D>], evals: &[ExtensionTarget<D>],
beta: ExtensionTarget<D>, beta: ExtensionTarget<D>,
common_data: &CommonCircuitData<F, D>,
) -> ExtensionTarget<D> { ) -> ExtensionTarget<D> {
let arity = 1 << arity_bits; let arity = 1 << arity_bits;
debug_assert_eq!(evals.len(), arity); debug_assert_eq!(evals.len(), arity);
@ -44,37 +48,62 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
let coset_start = self.mul(start, x); let coset_start = self.mul(start, x);
// The answer is gotten by interpolating {(x*g^i, P(x*g^i))} and evaluating at beta. // The answer is gotten by interpolating {(x*g^i, P(x*g^i))} and evaluating at beta.
self.interpolate_coset(arity_bits, coset_start, &evals, beta) // `HighDegreeInterpolationGate` has degree `arity`, so we use the low-degree gate if
// the arity is too large.
if arity > common_data.quotient_degree_factor {
self.interpolate_coset::<LowDegreeInterpolationGate<F, D>>(
arity_bits,
coset_start,
&evals,
beta,
)
} else {
self.interpolate_coset::<HighDegreeInterpolationGate<F, D>>(
arity_bits,
coset_start,
&evals,
beta,
)
}
} }
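A minimal sketch of the selection rule above, assuming a `quotient_degree_factor` of 8 (a common choice, not a value fixed by this diff):

fn use_low_degree_gate(arity_bits: usize, quotient_degree_factor: usize) -> bool {
    // `HighDegreeInterpolationGate` has degree equal to the arity, so switch to
    // the low-degree gate once the arity exceeds the max constraint degree.
    (1 << arity_bits) > quotient_degree_factor
}

fn main() {
    assert!(use_low_degree_gate(4, 8)); // arity 16 > 8: use the low-degree gate
    assert!(!use_low_degree_gate(3, 8)); // arity 8 <= 8: high-degree gate is fine
}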
/// Make sure we have enough wires and routed wires to do the FRI checks efficiently. This check /// Make sure we have enough wires and routed wires to do the FRI checks efficiently. This check
/// isn't required -- without it we'd get errors elsewhere in the stack -- but just gives more /// isn't required -- without it we'd get errors elsewhere in the stack -- but just gives more
/// helpful errors. /// helpful errors.
fn check_recursion_config(&self, max_fri_arity: usize) { fn check_recursion_config(
&self,
max_fri_arity_bits: usize,
common_data: &CommonCircuitData<F, D>,
) {
let random_access = RandomAccessGate::<F, D>::new_from_config( let random_access = RandomAccessGate::<F, D>::new_from_config(
&self.config, &self.config,
max_fri_arity.max(1 << self.config.cap_height), max_fri_arity_bits.max(self.config.cap_height),
); );
let interpolation_gate = InterpolationGate::<F, D>::new(log2_strict(max_fri_arity)); let (interpolation_wires, interpolation_routed_wires) =
if 1 << max_fri_arity_bits > common_data.quotient_degree_factor {
let gate = LowDegreeInterpolationGate::<F, D>::new(max_fri_arity_bits);
(gate.num_wires(), gate.num_routed_wires())
} else {
let gate = HighDegreeInterpolationGate::<F, D>::new(max_fri_arity_bits);
(gate.num_wires(), gate.num_routed_wires())
};
let min_wires = random_access let min_wires = random_access.num_wires().max(interpolation_wires);
.num_wires()
.max(interpolation_gate.num_wires());
let min_routed_wires = random_access let min_routed_wires = random_access
.num_routed_wires() .num_routed_wires()
.max(interpolation_gate.num_routed_wires()); .max(interpolation_routed_wires);
assert!( assert!(
self.config.num_wires >= min_wires, self.config.num_wires >= min_wires,
"To efficiently perform FRI checks with an arity of {}, at least {} wires are needed. Consider reducing arity.", "To efficiently perform FRI checks with an arity of 2^{}, at least {} wires are needed. Consider reducing arity.",
max_fri_arity, max_fri_arity_bits,
min_wires min_wires
); );
assert!( assert!(
self.config.num_routed_wires >= min_routed_wires, self.config.num_routed_wires >= min_routed_wires,
"To efficiently perform FRI checks with an arity of {}, at least {} routed wires are needed. Consider reducing arity.", "To efficiently perform FRI checks with an arity of 2^{}, at least {} routed wires are needed. Consider reducing arity.",
max_fri_arity, max_fri_arity_bits,
min_routed_wires min_routed_wires
); );
} }
@ -108,8 +137,8 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
) { ) {
let config = &common_data.config; let config = &common_data.config;
if let Some(max_arity) = common_data.fri_params.max_arity() { if let Some(max_arity_bits) = common_data.fri_params.max_arity_bits() {
self.check_recursion_config(max_arity); self.check_recursion_config(max_arity_bits, common_data);
} }
debug_assert_eq!( debug_assert_eq!(
@ -233,7 +262,7 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
common_data: &CommonCircuitData<F, C, D>, common_data: &CommonCircuitData<F, C, D>,
) -> ExtensionTarget<D> { ) -> ExtensionTarget<D> {
assert!(D > 1, "Not implemented for D=1."); assert!(D > 1, "Not implemented for D=1.");
let config = self.config.clone(); let config = &common_data.config;
let degree_log = common_data.degree_bits; let degree_log = common_data.degree_bits;
debug_assert_eq!( debug_assert_eq!(
degree_log, degree_log,
@ -306,9 +335,13 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
common_data: &CommonCircuitData<F, C, D>, common_data: &CommonCircuitData<F, C, D>,
) { ) {
let n_log = log2_strict(n); let n_log = log2_strict(n);
// TODO: Do we need to range check `x_index` to a target smaller than `p`?
// Note that this `low_bits` decomposition permits non-canonical binary encodings. Here we
// verify that this has a negligible impact on soundness error.
Self::assert_noncanonical_indices_ok(&common_data.config);
let x_index = challenger.get_challenge(self); let x_index = challenger.get_challenge(self);
let mut x_index_bits = self.low_bits(x_index, n_log, 64); let mut x_index_bits = self.low_bits(x_index, n_log, F::BITS);
let cap_index = let cap_index =
self.le_sum(x_index_bits[x_index_bits.len() - common_data.config.cap_height..].iter()); self.le_sum(x_index_bits[x_index_bits.len() - common_data.config.cap_height..].iter());
with_context!( with_context!(
@ -376,6 +409,7 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
arity_bits, arity_bits,
evals, evals,
betas[i], betas[i],
common_data
) )
); );
@ -409,6 +443,26 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
); );
self.connect_extension(eval, old_eval); self.connect_extension(eval, old_eval);
} }
/// We decompose FRI query indices into bits without verifying that the decomposition given by
/// the prover is the canonical one. In particular, if `x_index < 2^field_bits - p`, then the
/// prover could supply the binary encoding of either `x_index` or `x_index + p`, since they are
/// congruent mod `p`. However, this only occurs with probability
/// p_ambiguous = (2^field_bits - p) / p
/// which is small for the field that we use in practice.
///
/// In particular, the soundness error of one FRI query is roughly the codeword rate, which
/// is much larger than this ambiguous-element probability given any reasonable parameters.
/// Thus ambiguous elements contribute a negligible amount to soundness error.
///
/// Here we compare the probabilities as a sanity check, to verify the claim above.
fn assert_noncanonical_indices_ok(config: &CircuitConfig) {
let num_ambiguous_elems = u64::MAX - F::ORDER + 1;
let query_error = config.rate();
let p_ambiguous = (num_ambiguous_elems as f64) / (F::ORDER as f64);
assert!(p_ambiguous < query_error * 1e-5,
"A non-negligible portion of field elements are in the range that permits non-canonical encodings. Need to do more analysis or enforce canonical encodings.");
}
} }
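To make the comparison concrete, here is a standalone back-of-the-envelope version of this check. The Goldilocks prime and a rate of 1/8 are illustrative assumptions, not values fixed by this diff:

fn main() {
    // Assumed Goldilocks prime p = 2^64 - 2^32 + 1 (illustrative choice).
    let p: u64 = 0xFFFF_FFFF_0000_0001;
    // Field elements below 2^64 - p have two valid 64-bit encodings mod p.
    let num_ambiguous_elems = u64::MAX - p + 1; // = 2^32 - 1
    let p_ambiguous = num_ambiguous_elems as f64 / p as f64; // ~2.3e-10
    let query_error = 1.0 / 8.0; // assumed codeword rate (rate_bits = 3)
    // ~2.3e-10 is far below 1.25e-6, so the ambiguity is negligible here.
    assert!(p_ambiguous < query_error * 1e-5);
}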
#[derive(Copy, Clone)] #[derive(Copy, Clone)]


@ -2,6 +2,8 @@ use std::borrow::Borrow;
use crate::field::extension_field::Extendable; use crate::field::extension_field::Extendable;
use crate::gates::arithmetic::ArithmeticExtensionGate; use crate::gates::arithmetic::ArithmeticExtensionGate;
use crate::field::field_types::{PrimeField, RichField};
use crate::gates::arithmetic_base::ArithmeticGate;
use crate::gates::exponentiation::ExponentiationGate; use crate::gates::exponentiation::ExponentiationGate;
use crate::iop::target::{BoolTarget, Target}; use crate::iop::target::{BoolTarget, Target};
use crate::plonk::circuit_builder::CircuitBuilder; use crate::plonk::circuit_builder::CircuitBuilder;
@ -32,18 +34,117 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
multiplicand_1: Target, multiplicand_1: Target,
addend: Target, addend: Target,
) -> Target { ) -> Target {
// If we're not configured to use the base arithmetic gate, just call arithmetic_extension.
if !self.config.use_base_arithmetic_gate {
let multiplicand_0_ext = self.convert_to_ext(multiplicand_0); let multiplicand_0_ext = self.convert_to_ext(multiplicand_0);
let multiplicand_1_ext = self.convert_to_ext(multiplicand_1); let multiplicand_1_ext = self.convert_to_ext(multiplicand_1);
let addend_ext = self.convert_to_ext(addend); let addend_ext = self.convert_to_ext(addend);
self.arithmetic_extension( return self
.arithmetic_extension(
const_0, const_0,
const_1, const_1,
multiplicand_0_ext, multiplicand_0_ext,
multiplicand_1_ext, multiplicand_1_ext,
addend_ext, addend_ext,
) )
.0[0] .0[0];
}
// See if we can determine the result without adding an `ArithmeticGate`.
if let Some(result) =
self.arithmetic_special_cases(const_0, const_1, multiplicand_0, multiplicand_1, addend)
{
return result;
}
// See if we've already computed the same operation.
let operation = BaseArithmeticOperation {
const_0,
const_1,
multiplicand_0,
multiplicand_1,
addend,
};
if let Some(&result) = self.base_arithmetic_results.get(&operation) {
return result;
}
// Otherwise, we must actually perform the operation using an ArithmeticExtensionGate slot.
let result = self.add_base_arithmetic_operation(operation);
self.base_arithmetic_results.insert(operation, result);
result
}
fn add_base_arithmetic_operation(&mut self, operation: BaseArithmeticOperation<F>) -> Target {
let (gate, i) = self.find_base_arithmetic_gate(operation.const_0, operation.const_1);
let wires_multiplicand_0 = Target::wire(gate, ArithmeticGate::wire_ith_multiplicand_0(i));
let wires_multiplicand_1 = Target::wire(gate, ArithmeticGate::wire_ith_multiplicand_1(i));
let wires_addend = Target::wire(gate, ArithmeticGate::wire_ith_addend(i));
self.connect(operation.multiplicand_0, wires_multiplicand_0);
self.connect(operation.multiplicand_1, wires_multiplicand_1);
self.connect(operation.addend, wires_addend);
Target::wire(gate, ArithmeticGate::wire_ith_output(i))
}
/// Checks for special cases where the value of
/// `const_0 * multiplicand_0 * multiplicand_1 + const_1 * addend`
/// can be determined without adding an `ArithmeticGate`.
fn arithmetic_special_cases(
&mut self,
const_0: F,
const_1: F,
multiplicand_0: Target,
multiplicand_1: Target,
addend: Target,
) -> Option<Target> {
let zero = self.zero();
let mul_0_const = self.target_as_constant(multiplicand_0);
let mul_1_const = self.target_as_constant(multiplicand_1);
let addend_const = self.target_as_constant(addend);
let first_term_zero =
const_0 == F::ZERO || multiplicand_0 == zero || multiplicand_1 == zero;
let second_term_zero = const_1 == F::ZERO || addend == zero;
// If both terms are constant, return their (constant) sum.
let first_term_const = if first_term_zero {
Some(F::ZERO)
} else if let (Some(x), Some(y)) = (mul_0_const, mul_1_const) {
Some(x * y * const_0)
} else {
None
};
let second_term_const = if second_term_zero {
Some(F::ZERO)
} else {
addend_const.map(|x| x * const_1)
};
if let (Some(x), Some(y)) = (first_term_const, second_term_const) {
return Some(self.constant(x + y));
}
if first_term_zero && const_1.is_one() {
return Some(addend);
}
if second_term_zero {
if let Some(x) = mul_0_const {
if (x * const_0).is_one() {
return Some(multiplicand_1);
}
}
if let Some(x) = mul_1_const {
if (x * const_0).is_one() {
return Some(multiplicand_0);
}
}
}
None
} }
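A hypothetical use of this folding, assuming a `builder: CircuitBuilder<F, D>` set up as in this crate's tests; since every operand is a circuit constant, no gate slot is consumed:

let two = builder.constant(F::TWO);
let three = builder.constant(F::from_canonical_u64(3));
let zero = builder.zero();
// 1 * (2 * 3) + 0 * 0 folds to the constant 6 at circuit-construction time.
let six = builder.arithmetic(F::ONE, F::ZERO, two, three, zero);

A repeated call with identical operands would instead hit the `base_arithmetic_results` memoization map rather than allocating a new gate slot.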
/// Computes `x * y + z`. /// Computes `x * y + z`.
@ -53,20 +154,20 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
/// Computes `x + C`. /// Computes `x + C`.
pub fn add_const(&mut self, x: Target, c: F) -> Target { pub fn add_const(&mut self, x: Target, c: F) -> Target {
let one = self.one(); let c = self.constant(c);
self.arithmetic(F::ONE, c, one, x, one) self.add(x, c)
} }
/// Computes `C * x`. /// Computes `C * x`.
pub fn mul_const(&mut self, c: F, x: Target) -> Target { pub fn mul_const(&mut self, c: F, x: Target) -> Target {
let zero = self.zero(); let c = self.constant(c);
self.mul_const_add(c, x, zero) self.mul(c, x)
} }
/// Computes `C * x + y`. /// Computes `C * x + y`.
pub fn mul_const_add(&mut self, c: F, x: Target, y: Target) -> Target { pub fn mul_const_add(&mut self, c: F, x: Target, y: Target) -> Target {
let one = self.one(); let c = self.constant(c);
self.arithmetic(c, F::ONE, x, one, y) self.mul_add(c, x, y)
} }
/// Computes `x * y - z`. /// Computes `x * y - z`.
@ -82,13 +183,8 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
} }
/// Add `n` `Target`s. /// Add `n` `Target`s.
// TODO: Can be made `D` times more efficient by using all wires of an `ArithmeticExtensionGate`.
pub fn add_many(&mut self, terms: &[Target]) -> Target { pub fn add_many(&mut self, terms: &[Target]) -> Target {
let terms_ext = terms terms.iter().fold(self.zero(), |acc, &t| self.add(acc, t))
.iter()
.map(|&t| self.convert_to_ext(t))
.collect::<Vec<_>>();
self.add_many_extension(&terms_ext).to_target_array()[0]
} }
/// Computes `x - y`. /// Computes `x - y`.
@ -106,16 +202,16 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
/// Multiply `n` `Target`s. /// Multiply `n` `Target`s.
pub fn mul_many(&mut self, terms: &[Target]) -> Target { pub fn mul_many(&mut self, terms: &[Target]) -> Target {
let terms_ext = terms terms
.iter() .iter()
.map(|&t| self.convert_to_ext(t)) .copied()
.collect::<Vec<_>>(); .reduce(|acc, t| self.mul(acc, t))
self.mul_many_extension(&terms_ext).to_target_array()[0] .unwrap_or_else(|| self.one())
} }
/// Exponentiate `base` to the power of `2^power_log`. /// Exponentiate `base` to the power of `2^power_log`.
pub fn exp_power_of_2(&mut self, base: Target, power_log: usize) -> Target { pub fn exp_power_of_2(&mut self, base: Target, power_log: usize) -> Target {
if power_log > ArithmeticExtensionGate::<D>::new_from_config(&self.config).num_ops { if power_log > self.num_base_arithmetic_ops_per_gate() {
// Cheaper to just use `ExponentiateGate`. // Cheaper to just use `ExponentiateGate`.
return self.exp_u64(base, 1 << power_log); return self.exp_u64(base, 1 << power_log);
} }
@ -169,8 +265,7 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
let base_t = self.constant(base); let base_t = self.constant(base);
let exponent_bits: Vec<_> = exponent_bits.into_iter().map(|b| *b.borrow()).collect(); let exponent_bits: Vec<_> = exponent_bits.into_iter().map(|b| *b.borrow()).collect();
if exponent_bits.len() > ArithmeticExtensionGate::<D>::new_from_config(&self.config).num_ops if exponent_bits.len() > self.num_base_arithmetic_ops_per_gate() {
{
// Cheaper to just use `ExponentiateGate`. // Cheaper to just use `ExponentiateGate`.
return self.exp_from_bits(base_t, exponent_bits); return self.exp_from_bits(base_t, exponent_bits);
} }
@ -220,3 +315,13 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
self.inverse_extension(x_ext).0[0] self.inverse_extension(x_ext).0[0]
} }
} }
/// Represents a base arithmetic operation in the circuit. Used to memoize results.
#[derive(Copy, Clone, Eq, PartialEq, Hash)]
pub(crate) struct BaseArithmeticOperation<F: PrimeField> {
const_0: F,
const_1: F,
multiplicand_0: Target,
multiplicand_1: Target,
addend: Target,
}


@ -1,8 +1,9 @@
use std::convert::TryInto;
use crate::field::extension_field::target::{ExtensionAlgebraTarget, ExtensionTarget}; use crate::field::extension_field::target::{ExtensionAlgebraTarget, ExtensionTarget};
use crate::field::extension_field::FieldExtension; use crate::field::extension_field::FieldExtension;
use crate::field::extension_field::{Extendable, OEF}; use crate::field::extension_field::{Extendable, OEF};
use crate::field::field_types::{Field, PrimeField, RichField};
use crate::gates::arithmetic_extension::ArithmeticExtensionGate;
use crate::gates::multiplication_extension::MulExtensionGate;
use crate::field::field_types::{Field, PrimeField}; use crate::field::field_types::{Field, PrimeField};
use crate::gates::arithmetic::ArithmeticExtensionGate; use crate::gates::arithmetic::ArithmeticExtensionGate;
use crate::iop::generator::{GeneratedValues, SimpleGenerator}; use crate::iop::generator::{GeneratedValues, SimpleGenerator};
@ -12,33 +13,6 @@ use crate::plonk::circuit_builder::CircuitBuilder;
use crate::util::bits_u64; use crate::util::bits_u64;
impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> { impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
/// Finds the last available arithmetic gate with the given constants or add one if there aren't any.
/// Returns `(g,i)` such that there is an arithmetic gate with the given constants at index
/// `g` and the gate's `i`-th operation is available.
fn find_arithmetic_gate(&mut self, const_0: F, const_1: F) -> (usize, usize) {
let (gate, i) = self
.free_arithmetic
.get(&(const_0, const_1))
.copied()
.unwrap_or_else(|| {
let gate = self.add_gate(
ArithmeticExtensionGate::new_from_config(&self.config),
vec![const_0, const_1],
);
(gate, 0)
});
// Update `free_arithmetic` with new values.
if i < ArithmeticExtensionGate::<D>::num_ops(&self.config) - 1 {
self.free_arithmetic
.insert((const_0, const_1), (gate, i + 1));
} else {
self.free_arithmetic.remove(&(const_0, const_1));
}
(gate, i)
}
pub fn arithmetic_extension( pub fn arithmetic_extension(
&mut self, &mut self,
const_0: F, const_0: F,
@ -59,7 +33,7 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
} }
// See if we've already computed the same operation. // See if we've already computed the same operation.
let operation = ArithmeticOperation { let operation = ExtensionArithmeticOperation {
const_0, const_0,
const_1, const_1,
multiplicand_0, multiplicand_0,
@ -70,15 +44,21 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
return result; return result;
} }
let result = if self.target_as_constant_ext(addend) == Some(F::Extension::ZERO) {
// If the addend is zero, we use a multiplication gate.
self.compute_mul_extension_operation(operation)
} else {
// Otherwise, we use an arithmetic gate.
self.compute_arithmetic_extension_operation(operation)
};
// Otherwise, we must actually perform the operation using an ArithmeticExtensionGate slot. // Otherwise, we must actually perform the operation using an ArithmeticExtensionGate slot.
let result = self.add_arithmetic_extension_operation(operation);
self.arithmetic_results.insert(operation, result); self.arithmetic_results.insert(operation, result);
result result
} }
fn add_arithmetic_extension_operation( fn compute_arithmetic_extension_operation(
&mut self, &mut self,
operation: ArithmeticOperation<F, D>, operation: ExtensionArithmeticOperation<F, D>,
) -> ExtensionTarget<D> { ) -> ExtensionTarget<D> {
let (gate, i) = self.find_arithmetic_gate(operation.const_0, operation.const_1); let (gate, i) = self.find_arithmetic_gate(operation.const_0, operation.const_1);
let wires_multiplicand_0 = ExtensionTarget::from_range( let wires_multiplicand_0 = ExtensionTarget::from_range(
@ -99,6 +79,22 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
ExtensionTarget::from_range(gate, ArithmeticExtensionGate::<D>::wires_ith_output(i)) ExtensionTarget::from_range(gate, ArithmeticExtensionGate::<D>::wires_ith_output(i))
} }
fn compute_mul_extension_operation(
&mut self,
operation: ExtensionArithmeticOperation<F, D>,
) -> ExtensionTarget<D> {
let (gate, i) = self.find_mul_gate(operation.const_0);
let wires_multiplicand_0 =
ExtensionTarget::from_range(gate, MulExtensionGate::<D>::wires_ith_multiplicand_0(i));
let wires_multiplicand_1 =
ExtensionTarget::from_range(gate, MulExtensionGate::<D>::wires_ith_multiplicand_1(i));
self.connect_extension(operation.multiplicand_0, wires_multiplicand_0);
self.connect_extension(operation.multiplicand_1, wires_multiplicand_1);
ExtensionTarget::from_range(gate, MulExtensionGate::<D>::wires_ith_output(i))
}
/// Checks for special cases where the value of /// Checks for special cases where the value of
/// `const_0 * multiplicand_0 * multiplicand_1 + const_1 * addend` /// `const_0 * multiplicand_0 * multiplicand_1 + const_1 * addend`
/// can be determined without adding an `ArithmeticGate`. /// can be determined without adding an `ArithmeticGate`.
@ -302,11 +298,11 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
/// Multiply `n` `ExtensionTarget`s. /// Multiply `n` `ExtensionTarget`s.
pub fn mul_many_extension(&mut self, terms: &[ExtensionTarget<D>]) -> ExtensionTarget<D> { pub fn mul_many_extension(&mut self, terms: &[ExtensionTarget<D>]) -> ExtensionTarget<D> {
let mut product = self.one_extension(); terms
for &term in terms { .iter()
product = self.mul_extension(product, term); .copied()
} .reduce(|acc, t| self.mul_extension(acc, t))
product .unwrap_or_else(|| self.one_extension())
} }
/// Like `mul_add`, but for `ExtensionTarget`s. /// Like `mul_add`, but for `ExtensionTarget`s.
@ -321,14 +317,14 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
/// Like `add_const`, but for `ExtensionTarget`s. /// Like `add_const`, but for `ExtensionTarget`s.
pub fn add_const_extension(&mut self, x: ExtensionTarget<D>, c: F) -> ExtensionTarget<D> { pub fn add_const_extension(&mut self, x: ExtensionTarget<D>, c: F) -> ExtensionTarget<D> {
let one = self.one_extension(); let c = self.constant_extension(c.into());
self.arithmetic_extension(F::ONE, c, one, x, one) self.add_extension(x, c)
} }
/// Like `mul_const`, but for `ExtensionTarget`s. /// Like `mul_const`, but for `ExtensionTarget`s.
pub fn mul_const_extension(&mut self, c: F, x: ExtensionTarget<D>) -> ExtensionTarget<D> { pub fn mul_const_extension(&mut self, c: F, x: ExtensionTarget<D>) -> ExtensionTarget<D> {
let zero = self.zero_extension(); let c = self.constant_extension(c.into());
self.mul_const_add_extension(c, x, zero) self.mul_extension(c, x)
} }
/// Like `mul_const_add`, but for `ExtensionTarget`s. /// Like `mul_const_add`, but for `ExtensionTarget`s.
@ -338,8 +334,8 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
x: ExtensionTarget<D>, x: ExtensionTarget<D>,
y: ExtensionTarget<D>, y: ExtensionTarget<D>,
) -> ExtensionTarget<D> { ) -> ExtensionTarget<D> {
let one = self.one_extension(); let c = self.constant_extension(c.into());
self.arithmetic_extension(c, F::ONE, x, one, y) self.mul_add_extension(c, x, y)
} }
/// Like `mul_add`, but for `ExtensionTarget`s. /// Like `mul_add`, but for `ExtensionTarget`s.
@ -544,9 +540,9 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
} }
} }
/// Represents an arithmetic operation in the circuit. Used to memoize results. /// Represents an extension arithmetic operation in the circuit. Used to memoize results.
#[derive(Copy, Clone, Eq, PartialEq, Hash)] #[derive(Copy, Clone, Eq, PartialEq, Hash)]
pub(crate) struct ArithmeticOperation<F: PrimeField + Extendable<D>, const D: usize> { pub(crate) struct ExtensionArithmeticOperation<F: PrimeField + Extendable<D>, const D: usize> {
const_0: F, const_0: F,
const_1: F, const_1: F,
multiplicand_0: ExtensionTarget<D>, multiplicand_0: ExtensionTarget<D>,
@ -556,11 +552,11 @@ pub(crate) struct ArithmeticOperation<F: PrimeField + Extendable<D>, const D: us
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use std::convert::TryInto;
use anyhow::Result; use anyhow::Result;
use crate::field::extension_field::algebra::ExtensionAlgebra; use crate::field::extension_field::algebra::ExtensionAlgebra;
use crate::field::extension_field::quartic::QuarticExtension;
use crate::field::extension_field::target::ExtensionAlgebraTarget;
use crate::field::field_types::Field; use crate::field::field_types::Field;
use crate::iop::witness::{PartialWitness, Witness}; use crate::iop::witness::{PartialWitness, Witness};
use crate::plonk::circuit_builder::CircuitBuilder; use crate::plonk::circuit_builder::CircuitBuilder;
@ -623,9 +619,7 @@ mod tests {
let yt = builder.constant_extension(y); let yt = builder.constant_extension(y);
let zt = builder.constant_extension(z); let zt = builder.constant_extension(z);
let comp_zt = builder.div_extension(xt, yt); let comp_zt = builder.div_extension(xt, yt);
let comp_zt_unsafe = builder.div_extension(xt, yt);
builder.connect_extension(zt, comp_zt); builder.connect_extension(zt, comp_zt);
builder.connect_extension(zt, comp_zt_unsafe);
let data = builder.build::<C>(); let data = builder.build::<C>();
let proof = data.prove(pw)?; let proof = data.prove(pw)?;
@ -642,23 +636,29 @@ mod tests {
let config = CircuitConfig::standard_recursion_config(); let config = CircuitConfig::standard_recursion_config();
let pw = PartialWitness::new(); let mut pw = PartialWitness::new();
let mut builder = CircuitBuilder::<F, D>::new(config); let mut builder = CircuitBuilder::<F, D>::new(config);
let x = FF::rand_vec(D); let xt =
let y = FF::rand_vec(D); ExtensionAlgebraTarget(builder.add_virtual_extension_targets(D).try_into().unwrap());
let xa = ExtensionAlgebra(x.try_into().unwrap()); let yt =
let ya = ExtensionAlgebra(y.try_into().unwrap()); ExtensionAlgebraTarget(builder.add_virtual_extension_targets(D).try_into().unwrap());
let za = xa * ya; let zt =
ExtensionAlgebraTarget(builder.add_virtual_extension_targets(D).try_into().unwrap());
let xt = builder.constant_ext_algebra(xa);
let yt = builder.constant_ext_algebra(ya);
let zt = builder.constant_ext_algebra(za);
let comp_zt = builder.mul_ext_algebra(xt, yt); let comp_zt = builder.mul_ext_algebra(xt, yt);
for i in 0..D { for i in 0..D {
builder.connect_extension(zt.0[i], comp_zt.0[i]); builder.connect_extension(zt.0[i], comp_zt.0[i]);
} }
let x = ExtensionAlgebra::<FF, D>(FF::rand_arr());
let y = ExtensionAlgebra::<FF, D>(FF::rand_arr());
let z = x * y;
for i in 0..D {
pw.set_extension_target(xt.0[i], x.0[i]);
pw.set_extension_target(yt.0[i], y.0[i]);
pw.set_extension_target(zt.0[i], z.0[i]);
}
let data = builder.build::<C>(); let data = builder.build::<C>();
let proof = data.prove(pw)?; let proof = data.prove(pw)?;


@ -0,0 +1,154 @@
use crate::field::extension_field::Extendable;
use crate::field::field_types::RichField;
use crate::gates::arithmetic_u32::U32ArithmeticGate;
use crate::gates::subtraction_u32::U32SubtractionGate;
use crate::iop::target::Target;
use crate::plonk::circuit_builder::CircuitBuilder;
#[derive(Clone, Copy, Debug)]
pub struct U32Target(pub Target);
impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
pub fn add_virtual_u32_target(&mut self) -> U32Target {
U32Target(self.add_virtual_target())
}
pub fn add_virtual_u32_targets(&mut self, n: usize) -> Vec<U32Target> {
self.add_virtual_targets(n)
.into_iter()
.map(U32Target)
.collect()
}
pub fn zero_u32(&mut self) -> U32Target {
U32Target(self.zero())
}
pub fn one_u32(&mut self) -> U32Target {
U32Target(self.one())
}
pub fn connect_u32(&mut self, x: U32Target, y: U32Target) {
self.connect(x.0, y.0)
}
pub fn assert_zero_u32(&mut self, x: U32Target) {
self.assert_zero(x.0)
}
/// Checks for special cases where the value of
/// `x * y + z`
/// can be determined without adding a `U32ArithmeticGate`.
pub fn arithmetic_u32_special_cases(
&mut self,
x: U32Target,
y: U32Target,
z: U32Target,
) -> Option<(U32Target, U32Target)> {
let x_const = self.target_as_constant(x.0);
let y_const = self.target_as_constant(y.0);
let z_const = self.target_as_constant(z.0);
// If both terms are constant, return their (constant) sum.
let first_term_const = if let (Some(xx), Some(yy)) = (x_const, y_const) {
Some(xx * yy)
} else {
None
};
if let (Some(a), Some(b)) = (first_term_const, z_const) {
let sum = (a + b).to_canonical_u64();
let (low, high) = (sum as u32, (sum >> 32) as u32);
return Some((self.constant_u32(low), self.constant_u32(high)));
}
None
}
// Returns x * y + z.
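// The result is at most (2^32 - 1)^2 + (2^32 - 1) = 2^64 - 2^32, so it always
// fits in the returned (low, high) pair of 32-bit halves.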
pub fn mul_add_u32(
&mut self,
x: U32Target,
y: U32Target,
z: U32Target,
) -> (U32Target, U32Target) {
if let Some(result) = self.arithmetic_u32_special_cases(x, y, z) {
return result;
}
let gate = U32ArithmeticGate::<F, D>::new_from_config(&self.config);
let (gate_index, copy) = self.find_u32_arithmetic_gate();
self.connect(
Target::wire(gate_index, gate.wire_ith_multiplicand_0(copy)),
x.0,
);
self.connect(
Target::wire(gate_index, gate.wire_ith_multiplicand_1(copy)),
y.0,
);
self.connect(Target::wire(gate_index, gate.wire_ith_addend(copy)), z.0);
let output_low = U32Target(Target::wire(
gate_index,
gate.wire_ith_output_low_half(copy),
));
let output_high = U32Target(Target::wire(
gate_index,
gate.wire_ith_output_high_half(copy),
));
(output_low, output_high)
}
pub fn add_u32(&mut self, a: U32Target, b: U32Target) -> (U32Target, U32Target) {
let one = self.one_u32();
self.mul_add_u32(a, one, b)
}
pub fn add_many_u32(&mut self, to_add: &[U32Target]) -> (U32Target, U32Target) {
match to_add.len() {
0 => (self.zero_u32(), self.zero_u32()),
1 => (to_add[0], self.zero_u32()),
2 => self.add_u32(to_add[0], to_add[1]),
_ => {
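// Chain pairwise additions, folding each step's carry into a running
// carry; since every addend is a 32-bit limb, the running carry stays
// well below 2^32, so the second output of the carry addition is zero.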
let (mut low, mut carry) = self.add_u32(to_add[0], to_add[1]);
for i in 2..to_add.len() {
let (new_low, new_carry) = self.add_u32(to_add[i], low);
let (combined_carry, _zero) = self.add_u32(carry, new_carry);
low = new_low;
carry = combined_carry;
}
(low, carry)
}
}
}
pub fn mul_u32(&mut self, a: U32Target, b: U32Target) -> (U32Target, U32Target) {
let zero = self.zero_u32();
self.mul_add_u32(a, b, zero)
}
// Returns x - y - borrow, as a pair (result, borrow), where borrow is 0 or 1 depending on whether borrowing from the next digit is required (iff y + borrow > x).
pub fn sub_u32(
&mut self,
x: U32Target,
y: U32Target,
borrow: U32Target,
) -> (U32Target, U32Target) {
let gate = U32SubtractionGate::<F, D>::new_from_config(&self.config);
let (gate_index, copy) = self.find_u32_subtraction_gate();
self.connect(Target::wire(gate_index, gate.wire_ith_input_x(copy)), x.0);
self.connect(Target::wire(gate_index, gate.wire_ith_input_y(copy)), y.0);
self.connect(
Target::wire(gate_index, gate.wire_ith_input_borrow(copy)),
borrow.0,
);
let output_result = U32Target(Target::wire(gate_index, gate.wire_ith_output_result(copy)));
let output_borrow = U32Target(Target::wire(gate_index, gate.wire_ith_output_borrow(copy)));
(output_result, output_borrow)
}
}
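A hypothetical use of this gadget API, again assuming a `builder: CircuitBuilder<F, D>` in scope as in the gadget tests:

let a = builder.add_virtual_u32_target();
let b = builder.add_virtual_u32_target();
let c = builder.add_virtual_u32_target();
// `low` and `high` are the 32-bit halves of a * b + c; the underlying
// U32ArithmeticGate is expected to range-check both halves.
let (low, high) = builder.mul_add_u32(a, b, c);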

src/gadgets/biguint.rs (new file, 395 lines)

@ -0,0 +1,395 @@
use std::marker::PhantomData;
use num::{BigUint, Integer};
use crate::field::extension_field::Extendable;
use crate::field::field_types::RichField;
use crate::gadgets::arithmetic_u32::U32Target;
use crate::iop::generator::{GeneratedValues, SimpleGenerator};
use crate::iop::target::{BoolTarget, Target};
use crate::iop::witness::{PartitionWitness, Witness};
use crate::plonk::circuit_builder::CircuitBuilder;
#[derive(Clone, Debug)]
pub struct BigUintTarget {
pub limbs: Vec<U32Target>,
}
impl BigUintTarget {
pub fn num_limbs(&self) -> usize {
self.limbs.len()
}
pub fn get_limb(&self, i: usize) -> U32Target {
self.limbs[i]
}
}
impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
pub fn constant_biguint(&mut self, value: &BigUint) -> BigUintTarget {
let limb_values = value.to_u32_digits();
let limbs = limb_values.iter().map(|&l| self.constant_u32(l)).collect();
BigUintTarget { limbs }
}
pub fn connect_biguint(&mut self, lhs: &BigUintTarget, rhs: &BigUintTarget) {
let min_limbs = lhs.num_limbs().min(rhs.num_limbs());
for i in 0..min_limbs {
self.connect_u32(lhs.get_limb(i), rhs.get_limb(i));
}
for i in min_limbs..lhs.num_limbs() {
self.assert_zero_u32(lhs.get_limb(i));
}
for i in min_limbs..rhs.num_limbs() {
self.assert_zero_u32(rhs.get_limb(i));
}
}
pub fn pad_biguints(
&mut self,
a: &BigUintTarget,
b: &BigUintTarget,
) -> (BigUintTarget, BigUintTarget) {
if a.num_limbs() > b.num_limbs() {
let mut padded_b = b.clone();
for _ in b.num_limbs()..a.num_limbs() {
padded_b.limbs.push(self.zero_u32());
}
(a.clone(), padded_b)
} else {
let mut padded_a = a.clone();
for _ in a.num_limbs()..b.num_limbs() {
padded_a.limbs.push(self.zero_u32());
}
(padded_a, b.clone())
}
}
pub fn cmp_biguint(&mut self, a: &BigUintTarget, b: &BigUintTarget) -> BoolTarget {
let (a, b) = self.pad_biguints(a, b);
self.list_le_u32(a.limbs, b.limbs)
}
pub fn add_virtual_biguint_target(&mut self, num_limbs: usize) -> BigUintTarget {
let limbs = (0..num_limbs)
.map(|_| self.add_virtual_u32_target())
.collect();
BigUintTarget { limbs }
}
// Add two `BigUintTarget`s.
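// Limb-wise addition with a 32-bit carry chain; the result always has one
// more limb than the wider input.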
pub fn add_biguint(&mut self, a: &BigUintTarget, b: &BigUintTarget) -> BigUintTarget {
let num_limbs = a.num_limbs().max(b.num_limbs());
let mut combined_limbs = vec![];
let mut carry = self.zero_u32();
for i in 0..num_limbs {
let a_limb = (i < a.num_limbs())
.then(|| a.limbs[i])
.unwrap_or_else(|| self.zero_u32());
let b_limb = (i < b.num_limbs())
.then(|| b.limbs[i])
.unwrap_or_else(|| self.zero_u32());
let (new_limb, new_carry) = self.add_many_u32(&[carry, a_limb, b_limb]);
carry = new_carry;
combined_limbs.push(new_limb);
}
combined_limbs.push(carry);
BigUintTarget {
limbs: combined_limbs,
}
}
// Subtract two `BigUintTarget`s. We assume that the first is larger than the second.
pub fn sub_biguint(&mut self, a: &BigUintTarget, b: &BigUintTarget) -> BigUintTarget {
let (a, b) = self.pad_biguints(a, b);
let num_limbs = a.limbs.len();
let mut result_limbs = vec![];
let mut borrow = self.zero_u32();
for i in 0..num_limbs {
let (result, new_borrow) = self.sub_u32(a.limbs[i], b.limbs[i], borrow);
result_limbs.push(result);
borrow = new_borrow;
}
// Since `a >= b` is assumed, the final borrow is zero; note this is an assumption, not an enforced constraint.
BigUintTarget {
limbs: result_limbs,
}
}
pub fn mul_biguint(&mut self, a: &BigUintTarget, b: &BigUintTarget) -> BigUintTarget {
let total_limbs = a.limbs.len() + b.limbs.len();
let mut to_add = vec![vec![]; total_limbs];
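// Schoolbook multiplication: the 64-bit product of limbs i and j contributes
// its low 32-bit half at position i + j and its high half at position
// i + j + 1; the per-position sums are re-normalized with a carry chain below.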
for i in 0..a.limbs.len() {
for j in 0..b.limbs.len() {
let (product, carry) = self.mul_u32(a.limbs[i], b.limbs[j]);
to_add[i + j].push(product);
to_add[i + j + 1].push(carry);
}
}
let mut combined_limbs = vec![];
let mut carry = self.zero_u32();
for summands in &mut to_add {
summands.push(carry);
let (new_result, new_carry) = self.add_many_u32(summands);
combined_limbs.push(new_result);
carry = new_carry;
}
combined_limbs.push(carry);
BigUintTarget {
limbs: combined_limbs,
}
}
// Returns x * y + z. This is no more efficient than mul-then-add; it's purely for convenience (only need to call one CircuitBuilder function).
pub fn mul_add_biguint(
&mut self,
x: &BigUintTarget,
y: &BigUintTarget,
z: &BigUintTarget,
) -> BigUintTarget {
let prod = self.mul_biguint(x, y);
self.add_biguint(&prod, z)
}
pub fn div_rem_biguint(
&mut self,
a: &BigUintTarget,
b: &BigUintTarget,
) -> (BigUintTarget, BigUintTarget) {
let a_len = a.limbs.len();
let b_len = b.limbs.len();
let div_num_limbs = if b_len > a_len + 1 {
0
} else {
a_len - b_len + 1
};
let div = self.add_virtual_biguint_target(div_num_limbs);
let rem = self.add_virtual_biguint_target(b_len);
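// The generator below supplies (div, rem) as out-of-circuit witness values;
// the constraints that follow enforce a == div * b + rem and rem <= b.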
self.add_simple_generator(BigUintDivRemGenerator::<F, D> {
a: a.clone(),
b: b.clone(),
div: div.clone(),
rem: rem.clone(),
_phantom: PhantomData,
});
let div_b = self.mul_biguint(&div, b);
let div_b_plus_rem = self.add_biguint(&div_b, &rem);
self.connect_biguint(a, &div_b_plus_rem);
let cmp_rem_b = self.cmp_biguint(&rem, b);
self.assert_one(cmp_rem_b.target);
(div, rem)
}
pub fn div_biguint(&mut self, a: &BigUintTarget, b: &BigUintTarget) -> BigUintTarget {
let (div, _rem) = self.div_rem_biguint(a, b);
div
}
pub fn rem_biguint(&mut self, a: &BigUintTarget, b: &BigUintTarget) -> BigUintTarget {
let (_div, rem) = self.div_rem_biguint(a, b);
rem
}
}
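Together, `mul_biguint` and `div_rem_biguint` give modular reduction, the basic building block for nonnative field arithmetic. A sketch, with `builder`, `x`, `y`, and a hypothetical `order: BigUint` assumed to be in scope:

let modulus = builder.constant_biguint(&order);
let prod = builder.mul_biguint(&x, &y);
// div_rem_biguint constrains prod == div * modulus + reduced.
let reduced = builder.rem_biguint(&prod, &modulus);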
#[derive(Debug)]
struct BigUintDivRemGenerator<F: RichField + Extendable<D>, const D: usize> {
a: BigUintTarget,
b: BigUintTarget,
div: BigUintTarget,
rem: BigUintTarget,
_phantom: PhantomData<F>,
}
impl<F: RichField + Extendable<D>, const D: usize> SimpleGenerator<F>
for BigUintDivRemGenerator<F, D>
{
fn dependencies(&self) -> Vec<Target> {
self.a
.limbs
.iter()
.chain(&self.b.limbs)
.map(|&l| l.0)
.collect()
}
fn run_once(&self, witness: &PartitionWitness<F>, out_buffer: &mut GeneratedValues<F>) {
let a = witness.get_biguint_target(self.a.clone());
let b = witness.get_biguint_target(self.b.clone());
let (div, rem) = a.div_rem(&b);
out_buffer.set_biguint_target(self.div.clone(), div);
out_buffer.set_biguint_target(self.rem.clone(), rem);
}
}
#[cfg(test)]
mod tests {
use anyhow::Result;
use num::{BigUint, FromPrimitive, Integer};
use rand::Rng;
use crate::iop::witness::Witness;
use crate::{
field::goldilocks_field::GoldilocksField,
iop::witness::PartialWitness,
plonk::{circuit_builder::CircuitBuilder, circuit_data::CircuitConfig, verifier::verify},
};
#[test]
fn test_biguint_add() -> Result<()> {
let mut rng = rand::thread_rng();
let x_value = BigUint::from_u128(rng.gen()).unwrap();
let y_value = BigUint::from_u128(rng.gen()).unwrap();
let expected_z_value = &x_value + &y_value;
type F = GoldilocksField;
let config = CircuitConfig::standard_recursion_config();
let mut pw = PartialWitness::new();
let mut builder = CircuitBuilder::<F, 4>::new(config);
let x = builder.add_virtual_biguint_target(x_value.to_u32_digits().len());
let y = builder.add_virtual_biguint_target(y_value.to_u32_digits().len());
let z = builder.add_biguint(&x, &y);
let expected_z = builder.add_virtual_biguint_target(expected_z_value.to_u32_digits().len());
builder.connect_biguint(&z, &expected_z);
pw.set_biguint_target(&x, &x_value);
pw.set_biguint_target(&y, &y_value);
pw.set_biguint_target(&expected_z, &expected_z_value);
let data = builder.build();
let proof = data.prove(pw).unwrap();
verify(proof, &data.verifier_only, &data.common)
}
#[test]
fn test_biguint_sub() -> Result<()> {
let mut rng = rand::thread_rng();
let mut x_value = BigUint::from_u128(rng.gen()).unwrap();
let mut y_value = BigUint::from_u128(rng.gen()).unwrap();
if y_value > x_value {
(x_value, y_value) = (y_value, x_value);
}
let expected_z_value = &x_value - &y_value;
type F = GoldilocksField;
let config = CircuitConfig::standard_recursion_config();
let pw = PartialWitness::new();
let mut builder = CircuitBuilder::<F, 4>::new(config);
let x = builder.constant_biguint(&x_value);
let y = builder.constant_biguint(&y_value);
let z = builder.sub_biguint(&x, &y);
let expected_z = builder.constant_biguint(&expected_z_value);
builder.connect_biguint(&z, &expected_z);
let data = builder.build();
let proof = data.prove(pw).unwrap();
verify(proof, &data.verifier_only, &data.common)
}
#[test]
fn test_biguint_mul() -> Result<()> {
let mut rng = rand::thread_rng();
let x_value = BigUint::from_u128(rng.gen()).unwrap();
let y_value = BigUint::from_u128(rng.gen()).unwrap();
let expected_z_value = &x_value * &y_value;
type F = GoldilocksField;
let config = CircuitConfig::standard_recursion_config();
let mut pw = PartialWitness::new();
let mut builder = CircuitBuilder::<F, 4>::new(config);
let x = builder.add_virtual_biguint_target(x_value.to_u32_digits().len());
let y = builder.add_virtual_biguint_target(y_value.to_u32_digits().len());
let z = builder.mul_biguint(&x, &y);
let expected_z = builder.add_virtual_biguint_target(expected_z_value.to_u32_digits().len());
builder.connect_biguint(&z, &expected_z);
pw.set_biguint_target(&x, &x_value);
pw.set_biguint_target(&y, &y_value);
pw.set_biguint_target(&expected_z, &expected_z_value);
let data = builder.build();
let proof = data.prove(pw).unwrap();
verify(proof, &data.verifier_only, &data.common)
}
#[test]
fn test_biguint_cmp() -> Result<()> {
let mut rng = rand::thread_rng();
let x_value = BigUint::from_u128(rng.gen()).unwrap();
let y_value = BigUint::from_u128(rng.gen()).unwrap();
type F = GoldilocksField;
let config = CircuitConfig::standard_recursion_config();
let pw = PartialWitness::new();
let mut builder = CircuitBuilder::<F, 4>::new(config);
let x = builder.constant_biguint(&x_value);
let y = builder.constant_biguint(&y_value);
let cmp = builder.cmp_biguint(&x, &y);
let expected_cmp = builder.constant_bool(x_value <= y_value);
builder.connect(cmp.target, expected_cmp.target);
let data = builder.build();
let proof = data.prove(pw).unwrap();
verify(proof, &data.verifier_only, &data.common)
}
#[test]
fn test_biguint_div_rem() -> Result<()> {
let mut rng = rand::thread_rng();
let mut x_value = BigUint::from_u128(rng.gen()).unwrap();
let mut y_value = BigUint::from_u128(rng.gen()).unwrap();
if y_value > x_value {
(x_value, y_value) = (y_value, x_value);
}
let (expected_div_value, expected_rem_value) = x_value.div_rem(&y_value);
type F = GoldilocksField;
let config = CircuitConfig::standard_recursion_config();
let pw = PartialWitness::new();
let mut builder = CircuitBuilder::<F, 4>::new(config);
let x = builder.constant_biguint(&x_value);
let y = builder.constant_biguint(&y_value);
let (div, rem) = builder.div_rem_biguint(&x, &y);
let expected_div = builder.constant_biguint(&expected_div_value);
let expected_rem = builder.constant_biguint(&expected_rem_value);
builder.connect_biguint(&div, &expected_div);
builder.connect_biguint(&rem, &expected_rem);
let data = builder.build();
let proof = data.prove(pw).unwrap();
verify(proof, &data.verifier_only, &data.common)
}
}

src/gadgets/curve.rs (new file, 368 lines)

@ -0,0 +1,368 @@
use crate::curve::curve_types::{AffinePoint, Curve, CurveScalar};
use crate::field::extension_field::Extendable;
use crate::field::field_types::{Field, RichField};
use crate::gadgets::nonnative::NonNativeTarget;
use crate::plonk::circuit_builder::CircuitBuilder;
/// A Target representing an affine point on the curve `C`. We use incomplete arithmetic for efficiency,
/// so we assume these points are not zero.
#[derive(Clone, Debug)]
pub struct AffinePointTarget<C: Curve> {
pub x: NonNativeTarget<C::BaseField>,
pub y: NonNativeTarget<C::BaseField>,
}
impl<C: Curve> AffinePointTarget<C> {
pub fn to_vec(&self) -> Vec<NonNativeTarget<C::BaseField>> {
vec![self.x.clone(), self.y.clone()]
}
}
impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
pub fn constant_affine_point<C: Curve>(
&mut self,
point: AffinePoint<C>,
) -> AffinePointTarget<C> {
debug_assert!(!point.zero);
AffinePointTarget {
x: self.constant_nonnative(point.x),
y: self.constant_nonnative(point.y),
}
}
pub fn connect_affine_point<C: Curve>(
&mut self,
lhs: &AffinePointTarget<C>,
rhs: &AffinePointTarget<C>,
) {
self.connect_nonnative(&lhs.x, &rhs.x);
self.connect_nonnative(&lhs.y, &rhs.y);
}
pub fn add_virtual_affine_point_target<C: Curve>(&mut self) -> AffinePointTarget<C> {
let x = self.add_virtual_nonnative_target();
let y = self.add_virtual_nonnative_target();
AffinePointTarget { x, y }
}
pub fn curve_assert_valid<C: Curve>(&mut self, p: &AffinePointTarget<C>) {
let a = self.constant_nonnative(C::A);
let b = self.constant_nonnative(C::B);
let y_squared = self.mul_nonnative(&p.y, &p.y);
let x_squared = self.mul_nonnative(&p.x, &p.x);
let x_cubed = self.mul_nonnative(&x_squared, &p.x);
let a_x = self.mul_nonnative(&a, &p.x);
let a_x_plus_b = self.add_nonnative(&a_x, &b);
let rhs = self.add_nonnative(&x_cubed, &a_x_plus_b);
self.connect_nonnative(&y_squared, &rhs);
}
pub fn curve_neg<C: Curve>(&mut self, p: &AffinePointTarget<C>) -> AffinePointTarget<C> {
let neg_y = self.neg_nonnative(&p.y);
AffinePointTarget {
x: p.x.clone(),
y: neg_y,
}
}
pub fn curve_double<C: Curve>(&mut self, p: &AffinePointTarget<C>) -> AffinePointTarget<C> {
let AffinePointTarget { x, y } = p;
let double_y = self.add_nonnative(y, y);
let inv_double_y = self.inv_nonnative(&double_y);
let x_squared = self.mul_nonnative(x, x);
let double_x_squared = self.add_nonnative(&x_squared, &x_squared);
let triple_x_squared = self.add_nonnative(&double_x_squared, &x_squared);
let a = self.constant_nonnative(C::A);
let triple_xx_a = self.add_nonnative(&triple_x_squared, &a);
let lambda = self.mul_nonnative(&triple_xx_a, &inv_double_y);
let lambda_squared = self.mul_nonnative(&lambda, &lambda);
let x_double = self.add_nonnative(x, x);
let x3 = self.sub_nonnative(&lambda_squared, &x_double);
let x_diff = self.sub_nonnative(x, &x3);
let lambda_x_diff = self.mul_nonnative(&lambda, &x_diff);
let y3 = self.sub_nonnative(&lambda_x_diff, y);
AffinePointTarget { x: x3, y: y3 }
}
// Add two points, which are assumed to be non-equal.
pub fn curve_add<C: Curve>(
&mut self,
p1: &AffinePointTarget<C>,
p2: &AffinePointTarget<C>,
) -> AffinePointTarget<C> {
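// Incomplete affine addition via the standard chord rule, computed
// projectively (u = y2 - y1, v = x2 - x1) and normalized by vvv = v^3 at
// the end; p1 == ±p2 is not supported.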
let AffinePointTarget { x: x1, y: y1 } = p1;
let AffinePointTarget { x: x2, y: y2 } = p2;
let u = self.sub_nonnative(y2, y1);
let uu = self.mul_nonnative(&u, &u);
let v = self.sub_nonnative(x2, x1);
let vv = self.mul_nonnative(&v, &v);
let vvv = self.mul_nonnative(&v, &vv);
let r = self.mul_nonnative(&vv, x1);
let diff = self.sub_nonnative(&uu, &vvv);
let r2 = self.add_nonnative(&r, &r);
let a = self.sub_nonnative(&diff, &r2);
let x3 = self.mul_nonnative(&v, &a);
let r_a = self.sub_nonnative(&r, &a);
let y3_first = self.mul_nonnative(&u, &r_a);
let y3_second = self.mul_nonnative(&vvv, y1);
let y3 = self.sub_nonnative(&y3_first, &y3_second);
let z3_inv = self.inv_nonnative(&vvv);
let x3_norm = self.mul_nonnative(&x3, &z3_inv);
let y3_norm = self.mul_nonnative(&y3, &z3_inv);
AffinePointTarget {
x: x3_norm,
y: y3_norm,
}
}
pub fn curve_scalar_mul<C: Curve>(
&mut self,
p: &AffinePointTarget<C>,
n: &NonNativeTarget<C::ScalarField>,
) -> AffinePointTarget<C> {
let one = self.constant_nonnative(C::BaseField::ONE);
let bits = self.split_nonnative_to_bits(n);
let bits_as_base: Vec<NonNativeTarget<C::BaseField>> =
bits.iter().map(|b| self.bool_to_nonnative(b)).collect();
let rando = (CurveScalar(C::ScalarField::rand()) * C::GENERATOR_PROJECTIVE).to_affine();
let randot = self.constant_affine_point(rando);
// Result starts at `rando`, which is later subtracted, because we don't support arithmetic with the zero point.
let mut result = self.add_virtual_affine_point_target();
self.connect_affine_point(&randot, &result);
let mut two_i_times_p = self.add_virtual_affine_point_target();
self.connect_affine_point(p, &two_i_times_p);
for bit in bits_as_base.iter() {
let not_bit = self.sub_nonnative(&one, bit);
let result_plus_2_i_p = self.curve_add(&result, &two_i_times_p);
let new_x_if_bit = self.mul_nonnative(bit, &result_plus_2_i_p.x);
let new_x_if_not_bit = self.mul_nonnative(&not_bit, &result.x);
let new_y_if_bit = self.mul_nonnative(bit, &result_plus_2_i_p.y);
let new_y_if_not_bit = self.mul_nonnative(&not_bit, &result.y);
let new_x = self.add_nonnative(&new_x_if_bit, &new_x_if_not_bit);
let new_y = self.add_nonnative(&new_y_if_bit, &new_y_if_not_bit);
result = AffinePointTarget { x: new_x, y: new_y };
two_i_times_p = self.curve_double(&two_i_times_p);
}
// Subtract off `result`'s initial value of `rando`.
let neg_r = self.curve_neg(&randot);
result = self.curve_add(&result, &neg_r);
result
}
}
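// Out-of-circuit reference for `curve_scalar_mul` above (an illustrative sketch; it assumes
// the `AffinePoint` operations exercised by the tests below — `Add`, `Neg`, `double`,
// `to_affine` — and that `AffinePoint` is `Copy`). `bits` are the little-endian bits of the
// scalar. The circuit cannot branch on a bit, so it replaces the `if` below with the x/y
// multiplexing seen above; starting from the random point `rando` keeps every `curve_add`
// away from the zero point and from equal operands with overwhelming probability.
fn curve_scalar_mul_reference<C: Curve>(p: AffinePoint<C>, bits: &[bool]) -> AffinePoint<C> {
    let rando = (CurveScalar(C::ScalarField::rand()) * C::GENERATOR_PROJECTIVE).to_affine();
    // `result` starts at the blinding point, which is subtracted at the end.
    let mut result = rando;
    // Invariant on entering iteration i: `two_i_times_p` holds 2^i * p.
    let mut two_i_times_p = p;
    for &bit in bits {
        if bit {
            result = (result + two_i_times_p).to_affine();
        }
        two_i_times_p = two_i_times_p.double();
    }
    // Subtract off the blinding point.
    (result + -rando).to_affine()
}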
#[cfg(test)]
mod tests {
use anyhow::Result;
use crate::curve::curve_types::{AffinePoint, Curve, CurveScalar};
use crate::curve::secp256k1::Secp256K1;
use crate::field::field_types::Field;
use crate::field::goldilocks_field::GoldilocksField;
use crate::field::secp256k1_base::Secp256K1Base;
use crate::field::secp256k1_scalar::Secp256K1Scalar;
use crate::iop::witness::PartialWitness;
use crate::plonk::circuit_builder::CircuitBuilder;
use crate::plonk::circuit_data::CircuitConfig;
use crate::plonk::verifier::verify;
#[test]
fn test_curve_point_is_valid() -> Result<()> {
type F = GoldilocksField;
const D: usize = 4;
let config = CircuitConfig::standard_recursion_config();
let pw = PartialWitness::new();
let mut builder = CircuitBuilder::<F, D>::new(config);
let g = Secp256K1::GENERATOR_AFFINE;
let g_target = builder.constant_affine_point(g);
let neg_g_target = builder.curve_neg(&g_target);
builder.curve_assert_valid(&g_target);
builder.curve_assert_valid(&neg_g_target);
let data = builder.build();
let proof = data.prove(pw).unwrap();
verify(proof, &data.verifier_only, &data.common)
}
#[test]
#[should_panic]
fn test_curve_point_is_not_valid() {
type F = GoldilocksField;
const D: usize = 4;
let config = CircuitConfig::standard_recursion_config();
let pw = PartialWitness::new();
let mut builder = CircuitBuilder::<F, D>::new(config);
let g = Secp256K1::GENERATOR_AFFINE;
let not_g = AffinePoint::<Secp256K1> {
x: g.x,
y: g.y + Secp256K1Base::ONE,
zero: g.zero,
};
let not_g_target = builder.constant_affine_point(not_g);
builder.curve_assert_valid(&not_g_target);
let data = builder.build();
let proof = data.prove(pw).unwrap();
verify(proof, &data.verifier_only, &data.common).unwrap();
}
#[test]
fn test_curve_double() -> Result<()> {
type F = GoldilocksField;
const D: usize = 4;
let config = CircuitConfig::standard_recursion_config();
let pw = PartialWitness::new();
let mut builder = CircuitBuilder::<F, D>::new(config);
let g = Secp256K1::GENERATOR_AFFINE;
let g_target = builder.constant_affine_point(g);
let neg_g_target = builder.curve_neg(&g_target);
let double_g = g.double();
let double_g_expected = builder.constant_affine_point(double_g);
builder.curve_assert_valid(&double_g_expected);
let double_neg_g = (-g).double();
let double_neg_g_expected = builder.constant_affine_point(double_neg_g);
builder.curve_assert_valid(&double_neg_g_expected);
let double_g_actual = builder.curve_double(&g_target);
let double_neg_g_actual = builder.curve_double(&neg_g_target);
builder.curve_assert_valid(&double_g_actual);
builder.curve_assert_valid(&double_neg_g_actual);
builder.connect_affine_point(&double_g_expected, &double_g_actual);
builder.connect_affine_point(&double_neg_g_expected, &double_neg_g_actual);
let data = builder.build();
let proof = data.prove(pw).unwrap();
verify(proof, &data.verifier_only, &data.common)
}
#[test]
fn test_curve_add() -> Result<()> {
type F = GoldilocksField;
const D: usize = 4;
let config = CircuitConfig::standard_recursion_config();
let pw = PartialWitness::new();
let mut builder = CircuitBuilder::<F, D>::new(config);
let g = Secp256K1::GENERATOR_AFFINE;
let double_g = g.double();
let g_plus_2g = (g + double_g).to_affine();
let g_plus_2g_expected = builder.constant_affine_point(g_plus_2g);
builder.curve_assert_valid(&g_plus_2g_expected);
let g_target = builder.constant_affine_point(g);
let double_g_target = builder.curve_double(&g_target);
let g_plus_2g_actual = builder.curve_add(&g_target, &double_g_target);
builder.curve_assert_valid(&g_plus_2g_actual);
builder.connect_affine_point(&g_plus_2g_expected, &g_plus_2g_actual);
let data = builder.build();
let proof = data.prove(pw).unwrap();
verify(proof, &data.verifier_only, &data.common)
}
#[test]
#[ignore]
fn test_curve_mul() -> Result<()> {
type F = GoldilocksField;
const D: usize = 4;
let config = CircuitConfig {
num_routed_wires: 33,
..CircuitConfig::standard_recursion_config()
};
let pw = PartialWitness::new();
let mut builder = CircuitBuilder::<F, D>::new(config);
let g = Secp256K1::GENERATOR_AFFINE;
let five = Secp256K1Scalar::from_canonical_usize(5);
let five_scalar = CurveScalar::<Secp256K1>(five);
let five_g = (five_scalar * g.to_projective()).to_affine();
let five_g_expected = builder.constant_affine_point(five_g);
builder.curve_assert_valid(&five_g_expected);
let g_target = builder.constant_affine_point(g);
let five_target = builder.constant_nonnative(five);
let five_g_actual = builder.curve_scalar_mul(&g_target, &five_target);
builder.curve_assert_valid(&five_g_actual);
builder.connect_affine_point(&five_g_expected, &five_g_actual);
let data = builder.build();
let proof = data.prove(pw).unwrap();
verify(proof, &data.verifier_only, &data.common)
}
#[test]
#[ignore]
fn test_curve_random() -> Result<()> {
type F = GoldilocksField;
const D: usize = 4;
let config = CircuitConfig {
num_routed_wires: 33,
..CircuitConfig::standard_recursion_config()
};
let pw = PartialWitness::new();
let mut builder = CircuitBuilder::<F, D>::new(config);
let rando =
(CurveScalar(Secp256K1Scalar::rand()) * Secp256K1::GENERATOR_PROJECTIVE).to_affine();
let randot = builder.constant_affine_point(rando);
let two_target = builder.constant_nonnative(Secp256K1Scalar::TWO);
let randot_doubled = builder.curve_double(&randot);
let randot_times_two = builder.curve_scalar_mul(&randot, &two_target);
builder.connect_affine_point(&randot_doubled, &randot_times_two);
let data = builder.build();
let proof = data.prove(pw).unwrap();
verify(proof, &data.verifier_only, &data.common)
}
}

View File

@@ -1,22 +1,94 @@
+use std::ops::Range;
+
 use crate::field::extension_field::target::ExtensionTarget;
 use crate::field::extension_field::Extendable;
-use crate::gates::interpolation::InterpolationGate;
+use crate::field::field_types::RichField;
+use crate::gates::gate::Gate;
 use crate::iop::target::Target;
 use crate::plonk::circuit_builder::CircuitBuilder;

-impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
+/// Trait for gates which interpolate a polynomial, whose points are a (base field) coset of the multiplicative subgroup
+/// with the given size, and whose values are extension field elements, given by input wires.
+/// Outputs the evaluation of the interpolant at a given (extension field) evaluation point.
+pub(crate) trait InterpolationGate<F: Extendable<D>, const D: usize>:
+    Gate<F, D> + Copy
+{
+    fn new(subgroup_bits: usize) -> Self;
+
+    fn num_points(&self) -> usize;
+
+    /// Wire index of the coset shift.
+    fn wire_shift(&self) -> usize {
+        0
+    }
+
+    fn start_values(&self) -> usize {
+        1
+    }
+
+    /// Wire indices of the `i`th interpolant value.
+    fn wires_value(&self, i: usize) -> Range<usize> {
+        debug_assert!(i < self.num_points());
+        let start = self.start_values() + i * D;
+        start..start + D
+    }
+
+    fn start_evaluation_point(&self) -> usize {
+        self.start_values() + self.num_points() * D
+    }
+
+    /// Wire indices of the point to evaluate the interpolant at.
+    fn wires_evaluation_point(&self) -> Range<usize> {
+        let start = self.start_evaluation_point();
+        start..start + D
+    }
+
+    fn start_evaluation_value(&self) -> usize {
+        self.start_evaluation_point() + D
+    }
+
+    /// Wire indices of the interpolated value.
+    fn wires_evaluation_value(&self) -> Range<usize> {
+        let start = self.start_evaluation_value();
+        start..start + D
+    }
+
+    fn start_coeffs(&self) -> usize {
+        self.start_evaluation_value() + D
+    }
+
+    /// The number of routed wires required in the typical usage of this gate, where the points to
+    /// interpolate, the evaluation point, and the corresponding value are all routed.
+    fn num_routed_wires(&self) -> usize {
+        self.start_coeffs()
+    }
+
+    /// Wire indices of the interpolant's `i`th coefficient.
+    fn wires_coeff(&self, i: usize) -> Range<usize> {
+        debug_assert!(i < self.num_points());
+        let start = self.start_coeffs() + i * D;
+        start..start + D
+    }
+
+    fn end_coeffs(&self) -> usize {
+        self.start_coeffs() + D * self.num_points()
+    }
+}
+
+impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
     /// Interpolates a polynomial, whose points are a coset of the multiplicative subgroup with the
     /// given size, and whose values are given. Returns the evaluation of the interpolant at
     /// `evaluation_point`.
-    pub fn interpolate_coset(
+    pub(crate) fn interpolate_coset<G: InterpolationGate<F, D>>(
         &mut self,
         subgroup_bits: usize,
         coset_shift: Target,
         values: &[ExtensionTarget<D>],
         evaluation_point: ExtensionTarget<D>,
     ) -> ExtensionTarget<D> {
-        let gate = InterpolationGate::new(subgroup_bits);
-        let gate_index = self.add_gate(gate.clone(), vec![]);
+        let gate = G::new(subgroup_bits);
+        let gate_index = self.add_gate(gate, vec![]);
         self.connect(coset_shift, Target::wire(gate_index, gate.wire_shift()));

         for (i, &v) in values.iter().enumerate() {
             self.connect_extension(
@@ -37,6 +109,7 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
 mod tests {
     use anyhow::Result;

+    use crate::field::extension_field::quadratic::QuadraticExtension;
     use crate::field::extension_field::FieldExtension;
     use crate::field::field_types::Field;
     use crate::field::interpolation::interpolant;
@@ -83,9 +156,21 @@ mod tests {
         let zt = builder.constant_extension(z);

-        let eval = builder.interpolate_coset(subgroup_bits, coset_shift_target, &value_targets, zt);
+        let eval_hd = builder.interpolate_coset::<HighDegreeInterpolationGate<F, D>>(
+            subgroup_bits,
+            coset_shift_target,
+            &value_targets,
+            zt,
+        );
+        let eval_ld = builder.interpolate_coset::<LowDegreeInterpolationGate<F, D>>(
+            subgroup_bits,
+            coset_shift_target,
+            &value_targets,
+            zt,
+        );
         let true_eval_target = builder.constant_extension(true_eval);
-        builder.connect_extension(eval, true_eval_target);
+        builder.connect_extension(eval_hd, true_eval_target);
+        builder.connect_extension(eval_ld, true_eval_target);

         let data = builder.build::<C>();
         let proof = data.prove(pw)?;
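// Worked example of the InterpolationGate wire layout above (an illustrative sketch for
// D = 2 and subgroup_bits = 1, i.e. num_points = 2; the offsets follow directly from the
// trait's default methods):
//
//   wire_shift()             -> 0
//   wires_value(0)           -> 1..3     (start_values() = 1)
//   wires_value(1)           -> 3..5
//   wires_evaluation_point() -> 5..7     (start_evaluation_point() = 1 + 2 * 2)
//   wires_evaluation_value() -> 7..9
//   wires_coeff(0)           -> 9..11    (start_coeffs() = 9 = num_routed_wires())
//   wires_coeff(1)           -> 11..13
//   end_coeffs()             -> 13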

View File

@@ -1,8 +1,13 @@
 pub mod arithmetic;
 pub mod arithmetic_extension;
+pub mod arithmetic_u32;
+pub mod biguint;
+pub mod curve;
 pub mod hash;
 pub mod insert;
 pub mod interpolation;
+pub mod multiple_comparison;
+pub mod nonnative;
 pub mod permutation;
 pub mod polynomial;
 pub mod random_access;

View File

@@ -0,0 +1,138 @@
use super::arithmetic_u32::U32Target;
use crate::field::extension_field::Extendable;
use crate::field::field_types::RichField;
use crate::gates::comparison::ComparisonGate;
use crate::iop::target::{BoolTarget, Target};
use crate::plonk::circuit_builder::CircuitBuilder;
use crate::util::ceil_div_usize;
impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
/// Returns true if `a` is less than or equal to `b`, where both are interpreted as little-endian
/// lists of base-`2^num_bits` limbs of a large value.
/// This range-checks its inputs.
pub fn list_le(&mut self, a: Vec<Target>, b: Vec<Target>, num_bits: usize) -> BoolTarget {
assert_eq!(
a.len(),
b.len(),
"Comparison must be between same number of inputs and outputs"
);
let n = a.len();
let chunk_bits = 2;
let num_chunks = ceil_div_usize(num_bits, chunk_bits);
let one = self.one();
let mut result = one;
for i in 0..n {
let a_le_b_gate = ComparisonGate::new(num_bits, num_chunks);
let a_le_b_gate_index = self.add_gate(a_le_b_gate.clone(), vec![]);
self.connect(
Target::wire(a_le_b_gate_index, a_le_b_gate.wire_first_input()),
a[i],
);
self.connect(
Target::wire(a_le_b_gate_index, a_le_b_gate.wire_second_input()),
b[i],
);
let a_le_b_result = Target::wire(a_le_b_gate_index, a_le_b_gate.wire_result_bool());
let b_le_a_gate = ComparisonGate::new(num_bits, num_chunks);
let b_le_a_gate_index = self.add_gate(b_le_a_gate.clone(), vec![]);
self.connect(
Target::wire(b_le_a_gate_index, b_le_a_gate.wire_first_input()),
b[i],
);
self.connect(
Target::wire(b_le_a_gate_index, b_le_a_gate.wire_second_input()),
a[i],
);
let b_le_a_result = Target::wire(b_le_a_gate_index, b_le_a_gate.wire_result_bool());
let these_limbs_equal = self.mul(a_le_b_result, b_le_a_result);
let these_limbs_less_than = self.sub(one, b_le_a_result);
result = self.mul_add(these_limbs_equal, result, these_limbs_less_than);
}
// `result` being boolean is an invariant, maintained because its new value is always
// `x * result + y`, where `x` and `y` are booleans that are not simultaneously true.
BoolTarget::new_unsafe(result)
}
/// Helper function for comparing, specifically, lists of `U32Target`s.
pub fn list_le_u32(&mut self, a: Vec<U32Target>, b: Vec<U32Target>) -> BoolTarget {
let a_targets = a.iter().map(|&t| t.0).collect();
let b_targets = b.iter().map(|&t| t.0).collect();
self.list_le(a_targets, b_targets, 32)
}
}
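// Out-of-circuit reference for `list_le` above (an illustrative sketch, plain Rust): scanning
// from the least significant limb, the verdict from lower limbs survives only while the
// current limbs are equal, mirroring `result = these_limbs_equal * result +
// these_limbs_less_than` (the two terms are never simultaneously 1, so `result` stays boolean).
fn list_le_reference(a: &[u64], b: &[u64]) -> bool {
    let mut result = true; // vacuously true for empty lists
    for (&x, &y) in a.iter().zip(b) {
        result = (x == y && result) || x < y;
    }
    result
}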
#[cfg(test)]
mod tests {
use anyhow::Result;
use num::BigUint;
use rand::Rng;
use crate::field::field_types::Field;
use crate::field::goldilocks_field::GoldilocksField;
use crate::iop::witness::PartialWitness;
use crate::plonk::circuit_builder::CircuitBuilder;
use crate::plonk::circuit_data::CircuitConfig;
use crate::plonk::verifier::verify;
fn test_list_le(size: usize, num_bits: usize) -> Result<()> {
type F = GoldilocksField;
let config = CircuitConfig::standard_recursion_config();
let pw = PartialWitness::new();
let mut builder = CircuitBuilder::<F, 4>::new(config);
let mut rng = rand::thread_rng();
let lst1: Vec<u64> = (0..size)
.map(|_| rng.gen_range(0..(1 << num_bits)))
.collect();
let lst2: Vec<u64> = (0..size)
.map(|_| rng.gen_range(0..(1 << num_bits)))
.collect();
let a_biguint = BigUint::from_slice(
&lst1
.iter()
.flat_map(|&x| [x as u32, (x >> 32) as u32])
.collect::<Vec<_>>(),
);
let b_biguint = BigUint::from_slice(
&lst2
.iter()
.flat_map(|&x| [x as u32, (x >> 32) as u32])
.collect::<Vec<_>>(),
);
let a = lst1
.iter()
.map(|&x| builder.constant(F::from_canonical_u64(x)))
.collect();
let b = lst2
.iter()
.map(|&x| builder.constant(F::from_canonical_u64(x)))
.collect();
let result = builder.list_le(a, b, num_bits);
let expected_result = builder.constant_bool(a_biguint <= b_biguint);
builder.connect(result.target, expected_result.target);
let data = builder.build();
let proof = data.prove(pw).unwrap();
verify(proof, &data.verifier_only, &data.common)
}
#[test]
fn test_multiple_comparison() -> Result<()> {
for size in [1, 3, 6] {
for num_bits in [20, 32, 40, 44] {
test_list_le(size, num_bits).unwrap();
}
}
Ok(())
}
}

src/gadgets/nonnative.rs (new file, 342 lines)
View File

@@ -0,0 +1,342 @@
use std::marker::PhantomData;
use num::{BigUint, Zero};
use crate::field::field_types::RichField;
use crate::field::{extension_field::Extendable, field_types::Field};
use crate::gadgets::arithmetic_u32::U32Target;
use crate::gadgets::biguint::BigUintTarget;
use crate::iop::generator::{GeneratedValues, SimpleGenerator};
use crate::iop::target::{BoolTarget, Target};
use crate::iop::witness::{PartitionWitness, Witness};
use crate::plonk::circuit_builder::CircuitBuilder;
use crate::util::ceil_div_usize;
#[derive(Clone, Debug)]
pub struct NonNativeTarget<FF: Field> {
pub(crate) value: BigUintTarget,
_phantom: PhantomData<FF>,
}
impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
fn num_nonnative_limbs<FF: Field>() -> usize {
ceil_div_usize(FF::BITS, 32)
}
pub fn biguint_to_nonnative<FF: Field>(&mut self, x: &BigUintTarget) -> NonNativeTarget<FF> {
NonNativeTarget {
value: x.clone(),
_phantom: PhantomData,
}
}
pub fn nonnative_to_biguint<FF: Field>(&mut self, x: &NonNativeTarget<FF>) -> BigUintTarget {
x.value.clone()
}
pub fn constant_nonnative<FF: Field>(&mut self, x: FF) -> NonNativeTarget<FF> {
let x_biguint = self.constant_biguint(&x.to_biguint());
self.biguint_to_nonnative(&x_biguint)
}
// Assert that two `NonNativeTarget`s, both assumed to be in reduced form, are equal.
pub fn connect_nonnative<FF: Field>(
&mut self,
lhs: &NonNativeTarget<FF>,
rhs: &NonNativeTarget<FF>,
) {
self.connect_biguint(&lhs.value, &rhs.value);
}
pub fn add_virtual_nonnative_target<FF: Field>(&mut self) -> NonNativeTarget<FF> {
let num_limbs = Self::num_nonnative_limbs::<FF>();
let value = self.add_virtual_biguint_target(num_limbs);
NonNativeTarget {
value,
_phantom: PhantomData,
}
}
// Add two `NonNativeTarget`s.
pub fn add_nonnative<FF: Field>(
&mut self,
a: &NonNativeTarget<FF>,
b: &NonNativeTarget<FF>,
) -> NonNativeTarget<FF> {
let result = self.add_biguint(&a.value, &b.value);
// TODO: reduce add result with only one conditional subtraction
self.reduce(&result)
}
// Subtract two `NonNativeTarget`s.
pub fn sub_nonnative<FF: Field>(
&mut self,
a: &NonNativeTarget<FF>,
b: &NonNativeTarget<FF>,
) -> NonNativeTarget<FF> {
let order = self.constant_biguint(&FF::order());
let a_plus_order = self.add_biguint(&order, &a.value);
let result = self.sub_biguint(&a_plus_order, &b.value);
// TODO: reduce sub result with only one conditional addition?
self.reduce(&result)
}
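    // Reference identity for `sub_nonnative` above (a sketch of the argument): both inputs
    // are assumed reduced, i.e. 0 <= a, b < |FF|, so |FF| + a - b never underflows, and
    //
    //   a - b = (|FF| + a - b) mod |FF|,
    //
    // meaning a single reduction yields the canonical difference.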
pub fn mul_nonnative<FF: Field>(
&mut self,
a: &NonNativeTarget<FF>,
b: &NonNativeTarget<FF>,
) -> NonNativeTarget<FF> {
let result = self.mul_biguint(&a.value, &b.value);
self.reduce(&result)
}
pub fn neg_nonnative<FF: Field>(&mut self, x: &NonNativeTarget<FF>) -> NonNativeTarget<FF> {
let zero_target = self.constant_biguint(&BigUint::zero());
let zero_ff = self.biguint_to_nonnative(&zero_target);
self.sub_nonnative(&zero_ff, x)
}
pub fn inv_nonnative<FF: Field>(&mut self, x: &NonNativeTarget<FF>) -> NonNativeTarget<FF> {
let num_limbs = x.value.num_limbs();
let inv_biguint = self.add_virtual_biguint_target(num_limbs);
let inv = NonNativeTarget::<FF> {
value: inv_biguint,
_phantom: PhantomData,
};
self.add_simple_generator(NonNativeInverseGenerator::<F, D, FF> {
x: x.clone(),
inv: inv.clone(),
_phantom: PhantomData,
});
let product = self.mul_nonnative(x, &inv);
let one = self.constant_nonnative(FF::ONE);
self.connect_nonnative(&product, &one);
inv
}
pub fn div_rem_nonnative<FF: Field>(
&mut self,
x: &NonNativeTarget<FF>,
y: &NonNativeTarget<FF>,
) -> (NonNativeTarget<FF>, NonNativeTarget<FF>) {
let x_biguint = self.nonnative_to_biguint(x);
let y_biguint = self.nonnative_to_biguint(y);
let (div_biguint, rem_biguint) = self.div_rem_biguint(&x_biguint, &y_biguint);
let div = self.biguint_to_nonnative(&div_biguint);
let rem = self.biguint_to_nonnative(&rem_biguint);
(div, rem)
}
/// Returns `x % |FF|` as a `NonNativeTarget`.
fn reduce<FF: Field>(&mut self, x: &BigUintTarget) -> NonNativeTarget<FF> {
let modulus = FF::order();
let order_target = self.constant_biguint(&modulus);
let value = self.rem_biguint(x, &order_target);
NonNativeTarget {
value,
_phantom: PhantomData,
}
}
#[allow(dead_code)]
fn reduce_nonnative<FF: Field>(&mut self, x: &NonNativeTarget<FF>) -> NonNativeTarget<FF> {
let x_biguint = self.nonnative_to_biguint(x);
self.reduce(&x_biguint)
}
pub fn bool_to_nonnative<FF: Field>(&mut self, b: &BoolTarget) -> NonNativeTarget<FF> {
let limbs = vec![U32Target(b.target)];
let value = BigUintTarget { limbs };
NonNativeTarget {
value,
_phantom: PhantomData,
}
}
// Split a nonnative field element into bits.
pub fn split_nonnative_to_bits<FF: Field>(
&mut self,
x: &NonNativeTarget<FF>,
) -> Vec<BoolTarget> {
let num_limbs = x.value.num_limbs();
let mut result = Vec::with_capacity(num_limbs * 32);
for i in 0..num_limbs {
let limb = x.value.get_limb(i);
let bit_targets = self.split_le_base::<2>(limb.0, 32);
let mut bits: Vec<_> = bit_targets
.iter()
.map(|&t| BoolTarget::new_unsafe(t))
.collect();
result.append(&mut bits);
}
result
}
}
#[derive(Debug)]
struct NonNativeInverseGenerator<F: RichField + Extendable<D>, const D: usize, FF: Field> {
x: NonNativeTarget<FF>,
inv: NonNativeTarget<FF>,
_phantom: PhantomData<F>,
}
impl<F: RichField + Extendable<D>, const D: usize, FF: Field> SimpleGenerator<F>
for NonNativeInverseGenerator<F, D, FF>
{
fn dependencies(&self) -> Vec<Target> {
self.x.value.limbs.iter().map(|&l| l.0).collect()
}
fn run_once(&self, witness: &PartitionWitness<F>, out_buffer: &mut GeneratedValues<F>) {
let x = witness.get_nonnative_target(self.x.clone());
let inv = x.inverse();
out_buffer.set_nonnative_target(self.inv.clone(), inv);
}
}
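// Note on the nondeterministic-inverse pattern above (a sketch of the soundness argument):
// `NonNativeInverseGenerator` only hints the witness value inv = x^(-1), computed natively
// by the prover. Soundness comes from the constraint wired up in `inv_nonnative`,
//
//   x * inv = 1  (mod |FF|),
//
// so a wrong hint cannot satisfy the circuit, and for x = 0 no valid witness exists.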
#[cfg(test)]
mod tests {
use anyhow::Result;
use crate::field::field_types::Field;
use crate::field::goldilocks_field::GoldilocksField;
use crate::field::secp256k1_base::Secp256K1Base;
use crate::iop::witness::PartialWitness;
use crate::plonk::circuit_builder::CircuitBuilder;
use crate::plonk::circuit_data::CircuitConfig;
use crate::plonk::verifier::verify;
#[test]
fn test_nonnative_add() -> Result<()> {
type FF = Secp256K1Base;
let x_ff = FF::rand();
let y_ff = FF::rand();
let sum_ff = x_ff + y_ff;
type F = GoldilocksField;
let config = CircuitConfig::standard_recursion_config();
let pw = PartialWitness::new();
let mut builder = CircuitBuilder::<F, 4>::new(config);
let x = builder.constant_nonnative(x_ff);
let y = builder.constant_nonnative(y_ff);
let sum = builder.add_nonnative(&x, &y);
let sum_expected = builder.constant_nonnative(sum_ff);
builder.connect_nonnative(&sum, &sum_expected);
let data = builder.build();
let proof = data.prove(pw).unwrap();
verify(proof, &data.verifier_only, &data.common)
}
#[test]
fn test_nonnative_sub() -> Result<()> {
type FF = Secp256K1Base;
let x_ff = FF::rand();
let mut y_ff = FF::rand();
while y_ff.to_biguint() > x_ff.to_biguint() {
y_ff = FF::rand();
}
let diff_ff = x_ff - y_ff;
type F = GoldilocksField;
let config = CircuitConfig::standard_recursion_config();
let pw = PartialWitness::new();
let mut builder = CircuitBuilder::<F, 4>::new(config);
let x = builder.constant_nonnative(x_ff);
let y = builder.constant_nonnative(y_ff);
let diff = builder.sub_nonnative(&x, &y);
let diff_expected = builder.constant_nonnative(diff_ff);
builder.connect_nonnative(&diff, &diff_expected);
let data = builder.build();
let proof = data.prove(pw).unwrap();
verify(proof, &data.verifier_only, &data.common)
}
#[test]
fn test_nonnative_mul() -> Result<()> {
type FF = Secp256K1Base;
let x_ff = FF::rand();
let y_ff = FF::rand();
let product_ff = x_ff * y_ff;
type F = GoldilocksField;
let config = CircuitConfig::standard_recursion_config();
let pw = PartialWitness::new();
let mut builder = CircuitBuilder::<F, 4>::new(config);
let x = builder.constant_nonnative(x_ff);
let y = builder.constant_nonnative(y_ff);
let product = builder.mul_nonnative(&x, &y);
let product_expected = builder.constant_nonnative(product_ff);
builder.connect_nonnative(&product, &product_expected);
let data = builder.build();
let proof = data.prove(pw).unwrap();
verify(proof, &data.verifier_only, &data.common)
}
#[test]
fn test_nonnative_neg() -> Result<()> {
type FF = Secp256K1Base;
let x_ff = FF::rand();
let neg_x_ff = -x_ff;
type F = GoldilocksField;
let config = CircuitConfig::standard_recursion_config();
let pw = PartialWitness::new();
let mut builder = CircuitBuilder::<F, 4>::new(config);
let x = builder.constant_nonnative(x_ff);
let neg_x = builder.neg_nonnative(&x);
let neg_x_expected = builder.constant_nonnative(neg_x_ff);
builder.connect_nonnative(&neg_x, &neg_x_expected);
let data = builder.build();
let proof = data.prove(pw).unwrap();
verify(proof, &data.verifier_only, &data.common)
}
#[test]
fn test_nonnative_inv() -> Result<()> {
type FF = Secp256K1Base;
let x_ff = FF::rand();
let inv_x_ff = x_ff.inverse();
type F = GoldilocksField;
let config = CircuitConfig::standard_recursion_config();
let pw = PartialWitness::new();
let mut builder = CircuitBuilder::<F, 4>::new(config);
let x = builder.constant_nonnative(x_ff);
let inv_x = builder.inv_nonnative(&x);
let inv_x_expected = builder.constant_nonnative(inv_x_ff);
builder.connect_nonnative(&inv_x, &inv_x_expected);
let data = builder.build();
let proof = data.prove(pw).unwrap();
verify(proof, &data.verifier_only, &data.common)
}
}

View File

@@ -2,7 +2,6 @@ use std::collections::BTreeMap;
 use std::marker::PhantomData;

 use crate::field::{extension_field::Extendable, field_types::Field};
-use crate::gates::switch::SwitchGate;
 use crate::iop::generator::{GeneratedValues, SimpleGenerator};
 use crate::iop::target::Target;
 use crate::iop::witness::{PartitionWitness, Witness};
@@ -34,7 +33,6 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
                 self.assert_permutation_2x2(a[0].clone(), a[1].clone(), b[0].clone(), b[1].clone())
             }
             // For larger lists, we recursively use two smaller permutation networks.
-            //_ => self.assert_permutation_recursive(a, b)
             _ => self.assert_permutation_recursive(a, b),
         }
     }
@@ -72,22 +70,7 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
         let chunk_size = a1.len();

-        if self.current_switch_gates.len() < chunk_size {
-            self.current_switch_gates
-                .extend(vec![None; chunk_size - self.current_switch_gates.len()]);
-        }
-
-        let (gate, gate_index, mut next_copy) =
-            match self.current_switch_gates[chunk_size - 1].clone() {
-                None => {
-                    let gate = SwitchGate::<F, D>::new_from_config(&self.config, chunk_size);
-                    let gate_index = self.add_gate(gate.clone(), vec![]);
-                    (gate, gate_index, 0)
-                }
-                Some((gate, idx, next_copy)) => (gate, idx, next_copy),
-            };
-
-        let num_copies = gate.num_copies;
+        let (gate, gate_index, next_copy) = self.find_switch_gate(chunk_size);

         let mut c = Vec::new();
         let mut d = Vec::new();
@@ -112,13 +95,6 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
         let switch = Target::wire(gate_index, gate.wire_switch_bool(next_copy));

-        next_copy += 1;
-        if next_copy == num_copies {
-            self.current_switch_gates[chunk_size - 1] = None;
-        } else {
-            self.current_switch_gates[chunk_size - 1] = Some((gate, gate_index, next_copy));
-        }
-
         (switch, c, d)
     }
@@ -402,7 +378,7 @@ mod tests {
         let pw = PartialWitness::new();
         let mut builder = CircuitBuilder::<F, D>::new(config);

-        let lst: Vec<F> = (0..size * 2).map(|n| F::from_canonical_usize(n)).collect();
+        let lst: Vec<F> = (0..size * 2).map(F::from_canonical_usize).collect();
         let a: Vec<Vec<Target>> = lst[..]
             .chunks(2)
             .map(|pair| vec![builder.constant(pair[0]), builder.constant(pair[1])])

View File

@@ -63,4 +63,21 @@ impl<const D: usize> PolynomialCoeffsExtAlgebraTarget<D> {
         }
         acc
     }
+
+    /// Evaluate the polynomial at a point given its powers. The first power is the point itself, not 1.
+    pub fn eval_with_powers<F>(
+        &self,
+        builder: &mut CircuitBuilder<F, D>,
+        powers: &[ExtensionAlgebraTarget<D>],
+    ) -> ExtensionAlgebraTarget<D>
+    where
+        F: RichField + Extendable<D>,
+    {
+        debug_assert_eq!(self.0.len(), powers.len() + 1);
+        let acc = self.0[0];
+        self.0[1..]
+            .iter()
+            .zip(powers)
+            .fold(acc, |acc, (&x, &c)| builder.mul_add_ext_algebra(c, x, acc))
+    }
 }
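// Reference semantics for `eval_with_powers` above (an illustrative sketch): with
// coefficients a_0, ..., a_n in `self.0` and `powers` = (p, p^2, ..., p^n), the fold
// computes
//
//   a_0 + a_1 * p + a_2 * p^2 + ... + a_n * p^n,
//
// i.e. a full polynomial evaluation costing one `mul_add_ext_algebra` per coefficient,
// reusing powers of p that the caller already has.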

View File

@@ -3,49 +3,20 @@ use crate::field::extension_field::Extendable;
 use crate::gates::random_access::RandomAccessGate;
 use crate::iop::target::Target;
 use crate::plonk::circuit_builder::CircuitBuilder;
+use crate::util::log2_strict;

 impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
-    /// Finds the last available random access gate with the given `vec_size` or add one if there aren't any.
-    /// Returns `(g,i)` such that there is a random access gate with the given `vec_size` at index
-    /// `g` and the gate's `i`-th random access is available.
-    fn find_random_access_gate(&mut self, vec_size: usize) -> (usize, usize) {
-        let (gate, i) = self
-            .free_random_access
-            .get(&vec_size)
-            .copied()
-            .unwrap_or_else(|| {
-                let gate = self.add_gate(
-                    RandomAccessGate::new_from_config(&self.config, vec_size),
-                    vec![],
-                );
-                (gate, 0)
-            });
-
-        // Update `free_random_access` with new values.
-        if i < RandomAccessGate::<F, D>::max_num_copies(
-            self.config.num_routed_wires,
-            self.config.num_wires,
-            vec_size,
-        ) - 1
-        {
-            self.free_random_access.insert(vec_size, (gate, i + 1));
-        } else {
-            self.free_random_access.remove(&vec_size);
-        }
-
-        (gate, i)
-    }
-
     /// Checks that a `Target` matches a vector at a non-deterministic index.
     /// Note: `access_index` is not range-checked.
     pub fn random_access(&mut self, access_index: Target, claimed_element: Target, v: Vec<Target>) {
         let vec_size = v.len();
+        let bits = log2_strict(vec_size);
         debug_assert!(vec_size > 0);
         if vec_size == 1 {
             return self.connect(claimed_element, v[0]);
         }
-        let (gate_index, copy) = self.find_random_access_gate(vec_size);
-        let dummy_gate = RandomAccessGate::<F, D>::new_from_config(&self.config, vec_size);
+        let (gate_index, copy) = self.find_random_access_gate(bits);
+        let dummy_gate = RandomAccessGate::<F, D>::new_from_config(&self.config, bits);
         v.iter().enumerate().for_each(|(i, &val)| {
             self.connect(

View File

@@ -3,6 +3,8 @@ use std::marker::PhantomData;
 use itertools::izip;

 use crate::field::extension_field::Extendable;
-use crate::field::field_types::Field;
-use crate::gates::comparison::ComparisonGate;
+use crate::field::field_types::{Field, RichField};
+use crate::gates::assert_le::AssertLessThanGate;
 use crate::iop::generator::{GeneratedValues, SimpleGenerator};
@@ -40,9 +42,9 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
         self.assert_permutation(a_chunks, b_chunks);
     }

-    /// Add a ComparisonGate to assert that `lhs` is less than `rhs`, where their values are at most `bits` bits.
+    /// Add an AssertLessThanGate to assert that `lhs` is less than `rhs`, where their values are at most `bits` bits.
     pub fn assert_le(&mut self, lhs: Target, rhs: Target, bits: usize, num_chunks: usize) {
-        let gate = ComparisonGate::new(bits, num_chunks);
+        let gate = AssertLessThanGate::new(bits, num_chunks);
         let gate_index = self.add_gate(gate.clone(), vec![]);

         self.connect(Target::wire(gate_index, gate.wire_first_input()), lhs);
@@ -126,8 +128,7 @@ impl<F: Extendable<D>, const D: usize> SimpleGenerator<F> for MemoryOpSortGenerator<F, D>
     fn dependencies(&self) -> Vec<Target> {
         self.input_ops
             .iter()
-            .map(|op| vec![op.is_write.target, op.address, op.timestamp, op.value])
-            .flatten()
+            .flat_map(|op| vec![op.is_write.target, op.address, op.timestamp, op.value])
             .collect()
     }
@@ -223,7 +224,7 @@ mod tests {
         izip!(is_write_vals, address_vals, timestamp_vals, value_vals)
             .zip(combined_vals_u64)
             .collect::<Vec<_>>();
-        input_ops_and_keys.sort_by_key(|(_, val)| val.clone());
+        input_ops_and_keys.sort_by_key(|(_, val)| *val);
         let input_ops_sorted: Vec<_> = input_ops_and_keys.iter().map(|(x, _)| x).collect();

         let output_ops =

View File

@ -1,5 +1,7 @@
use std::borrow::Borrow; use std::borrow::Borrow;
use itertools::Itertools;
use crate::field::extension_field::Extendable; use crate::field::extension_field::Extendable;
use crate::field::field_types::Field; use crate::field::field_types::Field;
use crate::gates::base_sum::BaseSumGate; use crate::gates::base_sum::BaseSumGate;
@ -27,23 +29,25 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
/// Takes an iterator of bits `(b_i)` and returns `sum b_i * 2^i`, i.e., /// Takes an iterator of bits `(b_i)` and returns `sum b_i * 2^i`, i.e.,
/// the number with little-endian bit representation given by `bits`. /// the number with little-endian bit representation given by `bits`.
pub(crate) fn le_sum( pub(crate) fn le_sum(&mut self, bits: impl Iterator<Item = impl Borrow<BoolTarget>>) -> Target {
&mut self, let bits = bits.map(|b| *b.borrow()).collect_vec();
bits: impl ExactSizeIterator<Item = impl Borrow<BoolTarget>> + Clone,
) -> Target {
let num_bits = bits.len(); let num_bits = bits.len();
if num_bits == 0 { if num_bits == 0 {
return self.zero(); return self.zero();
} else if num_bits == 1 {
let mut bits = bits;
return bits.next().unwrap().borrow().target;
} else if num_bits == 2 {
let two = self.two();
let mut bits = bits;
let b0 = bits.next().unwrap().borrow().target;
let b1 = bits.next().unwrap().borrow().target;
return self.mul_add(two, b1, b0);
} }
// Check if it's cheaper to just do this with arithmetic operations.
let arithmetic_ops = num_bits - 1;
if arithmetic_ops <= self.num_base_arithmetic_ops_per_gate() {
let two = self.two();
let mut rev_bits = bits.iter().rev();
let mut sum = rev_bits.next().unwrap().target;
for &bit in rev_bits {
sum = self.mul_add(two, sum, bit.target);
}
return sum;
}
debug_assert!( debug_assert!(
BaseSumGate::<2>::START_LIMBS + num_bits <= self.config.num_routed_wires, BaseSumGate::<2>::START_LIMBS + num_bits <= self.config.num_routed_wires,
"Not enough routed wires." "Not enough routed wires."
@ -51,10 +55,10 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
let gate_type = BaseSumGate::<2>::new_from_config::<F>(&self.config); let gate_type = BaseSumGate::<2>::new_from_config::<F>(&self.config);
let gate_index = self.add_gate(gate_type, vec![]); let gate_index = self.add_gate(gate_type, vec![]);
for (limb, wire) in bits for (limb, wire) in bits
.clone() .iter()
.zip(BaseSumGate::<2>::START_LIMBS..BaseSumGate::<2>::START_LIMBS + num_bits) .zip(BaseSumGate::<2>::START_LIMBS..BaseSumGate::<2>::START_LIMBS + num_bits)
{ {
self.connect(limb.borrow().target, Target::wire(gate_index, wire)); self.connect(limb.target, Target::wire(gate_index, wire));
} }
for l in gate_type.limbs().skip(num_bits) { for l in gate_type.limbs().skip(num_bits) {
self.assert_zero(Target::wire(gate_index, l)); self.assert_zero(Target::wire(gate_index, l));
@ -62,7 +66,7 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
self.add_simple_generator(BaseSumGenerator::<2> { self.add_simple_generator(BaseSumGenerator::<2> {
gate_index, gate_index,
limbs: bits.map(|l| *l.borrow()).collect(), limbs: bits,
}); });
Target::wire(gate_index, BaseSumGate::<2>::WIRE_SUM) Target::wire(gate_index, BaseSumGate::<2>::WIRE_SUM)
@ -146,14 +150,14 @@ mod tests {
let pw = PartialWitness::new(); let pw = PartialWitness::new();
let mut builder = CircuitBuilder::<F, D>::new(config); let mut builder = CircuitBuilder::<F, D>::new(config);
let n = thread_rng().gen_range(0..(1 << 10)); let n = thread_rng().gen_range(0..(1 << 30));
let x = builder.constant(F::from_canonical_usize(n)); let x = builder.constant(F::from_canonical_usize(n));
let zero = builder._false(); let zero = builder._false();
let one = builder._true(); let one = builder._true();
let y = builder.le_sum( let y = builder.le_sum(
(0..10) (0..30)
.scan(n, |acc, _| { .scan(n, |acc, _| {
let tmp = *acc % 2; let tmp = *acc % 2;
*acc /= 2; *acc /= 2;
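// Worked example of the Horner-style fallback in `le_sum` above (an illustrative sketch):
// for little-endian bits (b0, b1, b2) = (1, 0, 1), i.e. the value 5, folding over the
// reversed bits computes
//
//   sum = 1              (b2)
//   sum = 2 * 1 + 0 = 2  (b1)
//   sum = 2 * 2 + 1 = 5  (b0)
//
// using num_bits - 1 = 2 base arithmetic operations, which is cheaper than a BaseSumGate
// for short bit strings.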

View File

@@ -24,8 +24,7 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
         let mut bits = Vec::with_capacity(num_bits);
         for &gate in &gates {
-            let start_limbs = BaseSumGate::<2>::START_LIMBS;
-            for limb_input in start_limbs..start_limbs + gate_type.num_limbs {
+            for limb_input in gate_type.limbs() {
                 // `new_unsafe` is safe here because BaseSumGate::<2> forces it to be in `{0, 1}`.
                 bits.push(BoolTarget::new_unsafe(Target::wire(gate, limb_input)));
             }
@@ -35,10 +34,11 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
     }

         let zero = self.zero();
+        let base = F::TWO.exp_u64(gate_type.num_limbs as u64);
         let mut acc = zero;
         for &gate in gates.iter().rev() {
             let sum = Target::wire(gate, BaseSumGate::<2>::WIRE_SUM);
-            acc = self.mul_const_add(F::from_canonical_usize(1 << gate_type.num_limbs), acc, sum);
+            acc = self.mul_const_add(base, acc, sum);
         }

         self.connect(acc, integer);
@@ -96,11 +96,18 @@ impl<F: RichField> SimpleGenerator<F> for WireSplitGenerator {
         for &gate in &self.gates {
             let sum = Target::wire(gate, BaseSumGate::<2>::WIRE_SUM);

-            out_buffer.set_target(
-                sum,
-                F::from_canonical_u64(integer_value & ((1 << self.num_limbs) - 1)),
-            );
+            // If num_limbs >= 64, we don't need to truncate since `integer_value` is already
+            // limited to 64 bits, and trying to do so would cause overflow. Hence the conditional.
+            let mut truncated_value = integer_value;
+            if self.num_limbs < 64 {
+                truncated_value = integer_value & ((1 << self.num_limbs) - 1);
                 integer_value >>= self.num_limbs;
+            } else {
+                integer_value = 0;
+            };
+
+            out_buffer.set_target(sum, F::from_canonical_u64(truncated_value));
         }

         debug_assert_eq!(
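// Worked example of the limb-splitting generator above (an illustrative sketch; num_limbs = 2,
// integer_value = 0b1101): the first gate's sum wire gets 0b01 = 1, the value shifts down to
// 0b11, and the second gate's sum wire gets 0b11 = 3. For num_limbs >= 64 the mask
// (1 << num_limbs) - 1 would overflow a u64, hence the branch that zeroes `integer_value`
// instead of shifting.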

View File

@@ -0,0 +1,212 @@
use crate::field::extension_field::target::ExtensionTarget;
use crate::field::extension_field::Extendable;
use crate::field::field_types::RichField;
use crate::gates::gate::Gate;
use crate::iop::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator};
use crate::iop::target::Target;
use crate::iop::witness::{PartitionWitness, Witness};
use crate::plonk::circuit_builder::CircuitBuilder;
use crate::plonk::circuit_data::CircuitConfig;
use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase};
/// A gate which can perform a weighted multiply-add, i.e. `result = c0 x y + c1 z`. If the config
/// supports enough routed wires, it can support several such operations in one gate.
#[derive(Debug)]
pub struct ArithmeticGate {
/// Number of arithmetic operations performed by an arithmetic gate.
pub num_ops: usize,
}
impl ArithmeticGate {
pub fn new_from_config(config: &CircuitConfig) -> Self {
Self {
num_ops: Self::num_ops(config),
}
}
/// Determine the maximum number of operations that can fit in one gate for the given config.
pub(crate) fn num_ops(config: &CircuitConfig) -> usize {
let wires_per_op = 4;
config.num_routed_wires / wires_per_op
}
pub fn wire_ith_multiplicand_0(i: usize) -> usize {
4 * i
}
pub fn wire_ith_multiplicand_1(i: usize) -> usize {
4 * i + 1
}
pub fn wire_ith_addend(i: usize) -> usize {
4 * i + 2
}
pub fn wire_ith_output(i: usize) -> usize {
4 * i + 3
}
}
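// Worked example of the layout above (an illustrative sketch; assumes a config with
// 80 routed wires, as in this codebase's standard recursion config): num_ops = 80 / 4 = 20,
// and operation i occupies wires 4*i..4*i + 4 as (multiplicand_0, multiplicand_1, addend,
// output), each enforcing
//
//   output = const_0 * multiplicand_0 * multiplicand_1 + const_1 * addend.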
impl<F: RichField + Extendable<D>, const D: usize> Gate<F, D> for ArithmeticGate {
fn id(&self) -> String {
format!("{:?}", self)
}
fn eval_unfiltered(&self, vars: EvaluationVars<F, D>) -> Vec<F::Extension> {
let const_0 = vars.local_constants[0];
let const_1 = vars.local_constants[1];
let mut constraints = Vec::new();
for i in 0..self.num_ops {
let multiplicand_0 = vars.local_wires[Self::wire_ith_multiplicand_0(i)];
let multiplicand_1 = vars.local_wires[Self::wire_ith_multiplicand_1(i)];
let addend = vars.local_wires[Self::wire_ith_addend(i)];
let output = vars.local_wires[Self::wire_ith_output(i)];
let computed_output = multiplicand_0 * multiplicand_1 * const_0 + addend * const_1;
constraints.push(output - computed_output);
}
constraints
}
fn eval_unfiltered_base(&self, vars: EvaluationVarsBase<F>) -> Vec<F> {
let const_0 = vars.local_constants[0];
let const_1 = vars.local_constants[1];
let mut constraints = Vec::new();
for i in 0..self.num_ops {
let multiplicand_0 = vars.local_wires[Self::wire_ith_multiplicand_0(i)];
let multiplicand_1 = vars.local_wires[Self::wire_ith_multiplicand_1(i)];
let addend = vars.local_wires[Self::wire_ith_addend(i)];
let output = vars.local_wires[Self::wire_ith_output(i)];
let computed_output = multiplicand_0 * multiplicand_1 * const_0 + addend * const_1;
constraints.push(output - computed_output);
}
constraints
}
fn eval_unfiltered_recursively(
&self,
builder: &mut CircuitBuilder<F, D>,
vars: EvaluationTargets<D>,
) -> Vec<ExtensionTarget<D>> {
let const_0 = vars.local_constants[0];
let const_1 = vars.local_constants[1];
let mut constraints = Vec::new();
for i in 0..self.num_ops {
let multiplicand_0 = vars.local_wires[Self::wire_ith_multiplicand_0(i)];
let multiplicand_1 = vars.local_wires[Self::wire_ith_multiplicand_1(i)];
let addend = vars.local_wires[Self::wire_ith_addend(i)];
let output = vars.local_wires[Self::wire_ith_output(i)];
let computed_output = {
let scaled_mul =
builder.mul_many_extension(&[const_0, multiplicand_0, multiplicand_1]);
builder.mul_add_extension(const_1, addend, scaled_mul)
};
let diff = builder.sub_extension(output, computed_output);
constraints.push(diff);
}
constraints
}
fn generators(
&self,
gate_index: usize,
local_constants: &[F],
) -> Vec<Box<dyn WitnessGenerator<F>>> {
(0..self.num_ops)
.map(|i| {
let g: Box<dyn WitnessGenerator<F>> = Box::new(
ArithmeticBaseGenerator {
gate_index,
const_0: local_constants[0],
const_1: local_constants[1],
i,
}
.adapter(),
);
g
})
.collect::<Vec<_>>()
}
fn num_wires(&self) -> usize {
self.num_ops * 4
}
fn num_constants(&self) -> usize {
2
}
fn degree(&self) -> usize {
3
}
fn num_constraints(&self) -> usize {
self.num_ops
}
}
#[derive(Clone, Debug)]
struct ArithmeticBaseGenerator<F: RichField + Extendable<D>, const D: usize> {
gate_index: usize,
const_0: F,
const_1: F,
i: usize,
}
impl<F: RichField + Extendable<D>, const D: usize> SimpleGenerator<F>
for ArithmeticBaseGenerator<F, D>
{
fn dependencies(&self) -> Vec<Target> {
[
ArithmeticGate::wire_ith_multiplicand_0(self.i),
ArithmeticGate::wire_ith_multiplicand_1(self.i),
ArithmeticGate::wire_ith_addend(self.i),
]
.iter()
.map(|&i| Target::wire(self.gate_index, i))
.collect()
}
fn run_once(&self, witness: &PartitionWitness<F>, out_buffer: &mut GeneratedValues<F>) {
let get_wire =
|wire: usize| -> F { witness.get_target(Target::wire(self.gate_index, wire)) };
let multiplicand_0 = get_wire(ArithmeticGate::wire_ith_multiplicand_0(self.i));
let multiplicand_1 = get_wire(ArithmeticGate::wire_ith_multiplicand_1(self.i));
let addend = get_wire(ArithmeticGate::wire_ith_addend(self.i));
let output_target = Target::wire(self.gate_index, ArithmeticGate::wire_ith_output(self.i));
let computed_output =
multiplicand_0 * multiplicand_1 * self.const_0 + addend * self.const_1;
out_buffer.set_target(output_target, computed_output)
}
}
#[cfg(test)]
mod tests {
use anyhow::Result;
use crate::field::goldilocks_field::GoldilocksField;
use crate::gates::arithmetic_base::ArithmeticGate;
use crate::gates::gate_testing::{test_eval_fns, test_low_degree};
use crate::plonk::circuit_data::CircuitConfig;
#[test]
fn low_degree() {
let gate = ArithmeticGate::new_from_config(&CircuitConfig::standard_recursion_config());
test_low_degree::<GoldilocksField, _, 4>(gate);
}
#[test]
fn eval_fns() -> Result<()> {
let gate = ArithmeticGate::new_from_config(&CircuitConfig::standard_recursion_config());
test_eval_fns::<GoldilocksField, _, 4>(gate)
}
}

View File

@@ -11,7 +11,8 @@ use crate::plonk::circuit_builder::CircuitBuilder;
 use crate::plonk::circuit_data::CircuitConfig;
 use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase};

-/// A gate which can a linear combination `c0*x*y+c1*z` twice with the same `x`.
+/// A gate which can perform a weighted multiply-add, i.e. `result = c0 x y + c1 z`. If the config
+/// supports enough routed wires, it can support several such operations in one gate.
 #[derive(Debug)]
 pub struct ArithmeticExtensionGate<const D: usize> {
     /// Number of arithmetic operations performed by an arithmetic gate.
@@ -203,7 +204,7 @@ mod tests {
     use anyhow::Result;

     use crate::field::goldilocks_field::GoldilocksField;
-    use crate::gates::arithmetic::ArithmeticExtensionGate;
+    use crate::gates::arithmetic_extension::ArithmeticExtensionGate;
     use crate::gates::gate_testing::{test_eval_fns, test_low_degree};
     use crate::plonk::circuit_data::CircuitConfig;
     use crate::plonk::config::{GenericConfig, PoseidonGoldilocksConfig};

View File

@@ -11,37 +11,49 @@ use crate::iop::target::Target;
 use crate::iop::wire::Wire;
 use crate::iop::witness::{PartitionWitness, Witness};
 use crate::plonk::circuit_builder::CircuitBuilder;
+use crate::plonk::circuit_data::CircuitConfig;
 use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase};

-/// Number of arithmetic operations performed by an arithmetic gate.
-pub const NUM_U32_ARITHMETIC_OPS: usize = 3;
-
 /// A gate to perform a basic mul-add on 32-bit values (we assume they are range-checked beforehand).
-#[derive(Debug)]
+#[derive(Copy, Clone, Debug)]
 pub struct U32ArithmeticGate<F: Extendable<D>, const D: usize> {
+    pub num_ops: usize,
     _phantom: PhantomData<F>,
 }

 impl<F: Extendable<D>, const D: usize> U32ArithmeticGate<F, D> {
-    pub fn wire_ith_multiplicand_0(i: usize) -> usize {
-        debug_assert!(i < NUM_U32_ARITHMETIC_OPS);
+    pub fn new_from_config(config: &CircuitConfig) -> Self {
+        Self {
+            num_ops: Self::num_ops(config),
+            _phantom: PhantomData,
+        }
+    }
+
+    pub(crate) fn num_ops(config: &CircuitConfig) -> usize {
+        let wires_per_op = 5 + Self::num_limbs();
+        let routed_wires_per_op = 5;
+        (config.num_wires / wires_per_op).min(config.num_routed_wires / routed_wires_per_op)
+    }
+
+    pub fn wire_ith_multiplicand_0(&self, i: usize) -> usize {
+        debug_assert!(i < self.num_ops);
         5 * i
     }
-    pub fn wire_ith_multiplicand_1(i: usize) -> usize {
-        debug_assert!(i < NUM_U32_ARITHMETIC_OPS);
+    pub fn wire_ith_multiplicand_1(&self, i: usize) -> usize {
+        debug_assert!(i < self.num_ops);
         5 * i + 1
     }
-    pub fn wire_ith_addend(i: usize) -> usize {
-        debug_assert!(i < NUM_U32_ARITHMETIC_OPS);
+    pub fn wire_ith_addend(&self, i: usize) -> usize {
+        debug_assert!(i < self.num_ops);
         5 * i + 2
     }
-    pub fn wire_ith_output_low_half(i: usize) -> usize {
-        debug_assert!(i < NUM_U32_ARITHMETIC_OPS);
+    pub fn wire_ith_output_low_half(&self, i: usize) -> usize {
+        debug_assert!(i < self.num_ops);
         5 * i + 3
     }
-    pub fn wire_ith_output_high_half(i: usize) -> usize {
-        debug_assert!(i < NUM_U32_ARITHMETIC_OPS);
+    pub fn wire_ith_output_high_half(&self, i: usize) -> usize {
+        debug_assert!(i < self.num_ops);
         5 * i + 4
     }
@@ -52,10 +64,10 @@ impl<F: Extendable<D>, const D: usize> U32ArithmeticGate<F, D> {
         64 / Self::limb_bits()
     }

-    pub fn wire_ith_output_jth_limb(i: usize, j: usize) -> usize {
-        debug_assert!(i < NUM_U32_ARITHMETIC_OPS);
+    pub fn wire_ith_output_jth_limb(&self, i: usize, j: usize) -> usize {
+        debug_assert!(i < self.num_ops);
         debug_assert!(j < Self::num_limbs());
-        5 * NUM_U32_ARITHMETIC_OPS + Self::num_limbs() * i + j
+        5 * self.num_ops + Self::num_limbs() * i + j
     }
 }
@@ -66,15 +78,15 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for U32ArithmeticGate<F, D> {
     fn eval_unfiltered(&self, vars: EvaluationVars<F, D>) -> Vec<F::Extension> {
         let mut constraints = Vec::with_capacity(self.num_constraints());
-        for i in 0..NUM_U32_ARITHMETIC_OPS {
-            let multiplicand_0 = vars.local_wires[Self::wire_ith_multiplicand_0(i)];
-            let multiplicand_1 = vars.local_wires[Self::wire_ith_multiplicand_1(i)];
-            let addend = vars.local_wires[Self::wire_ith_addend(i)];
+        for i in 0..self.num_ops {
+            let multiplicand_0 = vars.local_wires[self.wire_ith_multiplicand_0(i)];
+            let multiplicand_1 = vars.local_wires[self.wire_ith_multiplicand_1(i)];
+            let addend = vars.local_wires[self.wire_ith_addend(i)];

             let computed_output = multiplicand_0 * multiplicand_1 + addend;

-            let output_low = vars.local_wires[Self::wire_ith_output_low_half(i)];
-            let output_high = vars.local_wires[Self::wire_ith_output_high_half(i)];
+            let output_low = vars.local_wires[self.wire_ith_output_low_half(i)];
+            let output_high = vars.local_wires[self.wire_ith_output_high_half(i)];

             let base = F::Extension::from_canonical_u64(1 << 32u64);
             let combined_output = output_high * base + output_low;
@@ -86,7 +98,7 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for U32ArithmeticGate<F, D> {
             let midpoint = Self::num_limbs() / 2;
             let base = F::Extension::from_canonical_u64(1u64 << Self::limb_bits());
             for j in (0..Self::num_limbs()).rev() {
-                let this_limb = vars.local_wires[Self::wire_ith_output_jth_limb(i, j)];
+                let this_limb = vars.local_wires[self.wire_ith_output_jth_limb(i, j)];
                 let max_limb = 1 << Self::limb_bits();
                 let product = (0..max_limb)
                     .map(|x| this_limb - F::Extension::from_canonical_usize(x))
@@ -108,15 +120,15 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for U32ArithmeticGate<F, D> {
     fn eval_unfiltered_base(&self, vars: EvaluationVarsBase<F>) -> Vec<F> {
         let mut constraints = Vec::with_capacity(self.num_constraints());
-        for i in 0..NUM_U32_ARITHMETIC_OPS {
-            let multiplicand_0 = vars.local_wires[Self::wire_ith_multiplicand_0(i)];
-            let multiplicand_1 = vars.local_wires[Self::wire_ith_multiplicand_1(i)];
-            let addend = vars.local_wires[Self::wire_ith_addend(i)];
+        for i in 0..self.num_ops {
+            let multiplicand_0 = vars.local_wires[self.wire_ith_multiplicand_0(i)];
+            let multiplicand_1 = vars.local_wires[self.wire_ith_multiplicand_1(i)];
+            let addend = vars.local_wires[self.wire_ith_addend(i)];

             let computed_output = multiplicand_0 * multiplicand_1 + addend;

-            let output_low = vars.local_wires[Self::wire_ith_output_low_half(i)];
-            let output_high = vars.local_wires[Self::wire_ith_output_high_half(i)];
+            let output_low = vars.local_wires[self.wire_ith_output_low_half(i)];
+            let output_high = vars.local_wires[self.wire_ith_output_high_half(i)];

             let base = F::from_canonical_u64(1 << 32u64);
             let combined_output = output_high * base + output_low;
@@ -128,7 +140,7 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for U32ArithmeticGate<F, D> {
             let midpoint = Self::num_limbs() / 2;
             let base = F::from_canonical_u64(1u64 << Self::limb_bits());
             for j in (0..Self::num_limbs()).rev() {
-                let this_limb = vars.local_wires[Self::wire_ith_output_jth_limb(i, j)];
+                let this_limb = vars.local_wires[self.wire_ith_output_jth_limb(i, j)];
                 let max_limb = 1 << Self::limb_bits();
                 let product = (0..max_limb)
                     .map(|x| this_limb - F::from_canonical_usize(x))
@@ -155,15 +167,15 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for U32ArithmeticGate<F, D> {
     ) -> Vec<ExtensionTarget<D>> {
         let mut constraints = Vec::with_capacity(self.num_constraints());
-        for i in 0..NUM_U32_ARITHMETIC_OPS {
-            let multiplicand_0 = vars.local_wires[Self::wire_ith_multiplicand_0(i)];
-            let multiplicand_1 = vars.local_wires[Self::wire_ith_multiplicand_1(i)];
-            let addend = vars.local_wires[Self::wire_ith_addend(i)];
+        for i in 0..self.num_ops {
+            let multiplicand_0 = vars.local_wires[self.wire_ith_multiplicand_0(i)];
+            let multiplicand_1 = vars.local_wires[self.wire_ith_multiplicand_1(i)];
+            let addend = vars.local_wires[self.wire_ith_addend(i)];

             let computed_output = builder.mul_add_extension(multiplicand_0, multiplicand_1, addend);

-            let output_low = vars.local_wires[Self::wire_ith_output_low_half(i)];
-            let output_high = vars.local_wires[Self::wire_ith_output_high_half(i)];
+            let output_low = vars.local_wires[self.wire_ith_output_low_half(i)];
+            let output_high = vars.local_wires[self.wire_ith_output_high_half(i)];

             let base: F::Extension = F::from_canonical_u64(1 << 32u64).into();
             let base_target = builder.constant_extension(base);
@@ -177,7 +189,7 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for U32ArithmeticGate<F, D> {
             let base = builder
                 .constant_extension(F::Extension::from_canonical_u64(1u64 << Self::limb_bits()));
             for j in (0..Self::num_limbs()).rev() {
-                let this_limb = vars.local_wires[Self::wire_ith_output_jth_limb(i, j)];
+                let this_limb = vars.local_wires[self.wire_ith_output_jth_limb(i, j)];
                 let max_limb = 1 << Self::limb_bits();

                 let mut product = builder.one_extension();
@@ -210,10 +222,11 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for U32ArithmeticGate<F, D> {
         gate_index: usize,
         _local_constants: &[F],
     ) -> Vec<Box<dyn WitnessGenerator<F>>> {
-        (0..NUM_U32_ARITHMETIC_OPS)
+        (0..self.num_ops)
             .map(|i| {
                 let g: Box<dyn WitnessGenerator<F>> = Box::new(
                     U32ArithmeticGenerator {
+                        gate: *self,
                         gate_index,
                         i,
                         _phantom: PhantomData,
@@ -226,7 +239,7 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for U32ArithmeticGate<F, D> {
     }

     fn num_wires(&self) -> usize {
-        NUM_U32_ARITHMETIC_OPS * (5 + Self::num_limbs())
+        self.num_ops * (5 + Self::num_limbs())
     }

     fn num_constants(&self) -> usize {
@@ -238,12 +251,13 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for U32ArithmeticGate<F, D> {
     }

     fn num_constraints(&self) -> usize {
-        NUM_U32_ARITHMETIC_OPS * (3 + Self::num_limbs())
+        self.num_ops * (3 + Self::num_limbs())
     }
 }

 #[derive(Clone, Debug)]
 struct U32ArithmeticGenerator<F: Extendable<D>, const D: usize> {
+    gate: U32ArithmeticGate<F, D>,
     gate_index: usize,
     i: usize,
     _phantom: PhantomData<F>,
@@ -253,17 +267,11 @@ impl<F: Extendable<D>, const D: usize> SimpleGenerator<F> for U32ArithmeticGenerator<F, D>
     fn dependencies(&self) -> Vec<Target> {
         let local_target = |input| Target::wire(self.gate_index, input);

-        let mut deps = Vec::with_capacity(3);
-        deps.push(local_target(
-            U32ArithmeticGate::<F, D>::wire_ith_multiplicand_0(self.i),
-        ));
-        deps.push(local_target(
-            U32ArithmeticGate::<F, D>::wire_ith_multiplicand_1(self.i),
-        ));
-        deps.push(local_target(U32ArithmeticGate::<F, D>::wire_ith_addend(
-            self.i,
-        )));
-        deps
+        vec![
+            local_target(self.gate.wire_ith_multiplicand_0(self.i)),
+            local_target(self.gate.wire_ith_multiplicand_1(self.i)),
+            local_target(self.gate.wire_ith_addend(self.i)),
+        ]
     }

     fn run_once(&self, witness: &PartitionWitness<F>, out_buffer: &mut GeneratedValues<F>) {
@@ -274,11 +282,9 @@ impl<F: Extendable<D>, const D: usize> SimpleGenerator<F> for U32ArithmeticGenerator<F, D>
         let get_local_wire = |input| witness.get_wire(local_wire(input));

-        let multiplicand_0 =
-            get_local_wire(U32ArithmeticGate::<F, D>::wire_ith_multiplicand_0(self.i));
-        let multiplicand_1 =
-            get_local_wire(U32ArithmeticGate::<F, D>::wire_ith_multiplicand_1(self.i));
-        let addend = get_local_wire(U32ArithmeticGate::<F, D>::wire_ith_addend(self.i));
+        let multiplicand_0 = get_local_wire(self.gate.wire_ith_multiplicand_0(self.i));
+        let multiplicand_1 = get_local_wire(self.gate.wire_ith_multiplicand_1(self.i));
+        let addend = get_local_wire(self.gate.wire_ith_addend(self.i));
let output = multiplicand_0 * multiplicand_1 + addend; let output = multiplicand_0 * multiplicand_1 + addend;
let mut output_u64 = output.to_canonical_u64(); let mut output_u64 = output.to_canonical_u64();
@ -289,34 +295,25 @@ impl<F: Extendable<D>, const D: usize> SimpleGenerator<F> for U32ArithmeticGener
let output_high = F::from_canonical_u64(output_high_u64); let output_high = F::from_canonical_u64(output_high_u64);
let output_low = F::from_canonical_u64(output_low_u64); let output_low = F::from_canonical_u64(output_low_u64);
let output_high_wire = let output_high_wire = local_wire(self.gate.wire_ith_output_high_half(self.i));
local_wire(U32ArithmeticGate::<F, D>::wire_ith_output_high_half(self.i)); let output_low_wire = local_wire(self.gate.wire_ith_output_low_half(self.i));
let output_low_wire =
local_wire(U32ArithmeticGate::<F, D>::wire_ith_output_low_half(self.i));
out_buffer.set_wire(output_high_wire, output_high); out_buffer.set_wire(output_high_wire, output_high);
out_buffer.set_wire(output_low_wire, output_low); out_buffer.set_wire(output_low_wire, output_low);
let num_limbs = U32ArithmeticGate::<F, D>::num_limbs(); let num_limbs = U32ArithmeticGate::<F, D>::num_limbs();
let limb_base = 1 << U32ArithmeticGate::<F, D>::limb_bits(); let limb_base = 1 << U32ArithmeticGate::<F, D>::limb_bits();
let output_limbs_u64: Vec<_> = unfold((), move |_| { let output_limbs_u64 = unfold((), move |_| {
let ret = output_u64 % limb_base; let ret = output_u64 % limb_base;
output_u64 /= limb_base; output_u64 /= limb_base;
Some(ret) Some(ret)
}) })
.take(num_limbs) .take(num_limbs);
.collect(); let output_limbs_f = output_limbs_u64.map(F::from_canonical_u64);
let output_limbs_f: Vec<_> = output_limbs_u64
.iter()
.cloned()
.map(F::from_canonical_u64)
.collect();
for j in 0..num_limbs { for (j, output_limb) in output_limbs_f.enumerate() {
let wire = local_wire(U32ArithmeticGate::<F, D>::wire_ith_output_jth_limb( let wire = local_wire(self.gate.wire_ith_output_jth_limb(self.i, j));
self.i, j, out_buffer.set_wire(wire, output_limb);
));
out_buffer.set_wire(wire, output_limbs_f[j]);
} }
} }
} }
@ -330,7 +327,7 @@ mod tests {
use crate::field::field_types::Field; use crate::field::field_types::Field;
use crate::field::goldilocks_field::GoldilocksField; use crate::field::goldilocks_field::GoldilocksField;
use crate::gates::arithmetic_u32::{U32ArithmeticGate, NUM_U32_ARITHMETIC_OPS}; use crate::gates::arithmetic_u32::U32ArithmeticGate;
use crate::gates::gate::Gate; use crate::gates::gate::Gate;
use crate::gates::gate_testing::{test_eval_fns, test_low_degree}; use crate::gates::gate_testing::{test_eval_fns, test_low_degree};
use crate::hash::hash_types::HashOut; use crate::hash::hash_types::HashOut;
@ -340,16 +337,15 @@ mod tests {
#[test] #[test]
fn low_degree() { fn low_degree() {
test_low_degree::<GoldilocksField, _, 4>(U32ArithmeticGate::<GoldilocksField, 4> { test_low_degree::<GoldilocksField, _, 4>(U32ArithmeticGate::<GoldilocksField, 4> {
num_ops: 3,
_phantom: PhantomData, _phantom: PhantomData,
}) })
} }
#[test] #[test]
fn eval_fns() -> Result<()> { fn eval_fns() -> Result<()> {
const D: usize = 2; test_eval_fns::<GoldilocksField, _, 4>(U32ArithmeticGate::<GoldilocksField, 4> {
type C = PoseidonGoldilocksConfig; num_ops: 3,
type F = <C as GenericConfig<D>>::F;
test_eval_fns::<F, C, _, D>(U32ArithmeticGate::<F, D> {
_phantom: PhantomData, _phantom: PhantomData,
}) })
} }
@ -360,6 +356,7 @@ mod tests {
type C = PoseidonGoldilocksConfig; type C = PoseidonGoldilocksConfig;
type F = <C as GenericConfig<D>>::F; type F = <C as GenericConfig<D>>::F;
type FF = <C as GenericConfig<D>>::FE; type FF = <C as GenericConfig<D>>::FE;
const NUM_U32_ARITHMETIC_OPS: usize = 3;
fn get_wires( fn get_wires(
multiplicands_0: Vec<u64>, multiplicands_0: Vec<u64>,
@ -387,8 +384,7 @@ mod tests {
output /= limb_base; output /= limb_base;
} }
let mut output_limbs_f: Vec<_> = output_limbs let mut output_limbs_f: Vec<_> = output_limbs
.iter() .into_iter()
.cloned()
.map(F::from_canonical_u64) .map(F::from_canonical_u64)
.collect(); .collect();
@ -418,6 +414,7 @@ mod tests {
.collect(); .collect();
let gate = U32ArithmeticGate::<F, D> { let gate = U32ArithmeticGate::<F, D> {
num_ops: NUM_U32_ARITHMETIC_OPS,
_phantom: PhantomData, _phantom: PhantomData,
}; };
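For reference, the witness logic that `U32ArithmeticGenerator` implements in this file can be summarized by the standalone sketch below. It is illustrative only: the function name is made up, plain `u64` arithmetic stands in for field elements, and `limb_bits`/`num_limbs` are taken as parameters where the gate fixes them via `Self::limb_bits()` and `Self::num_limbs()`.

    /// Decompose `m0 * m1 + a` into 32-bit halves plus base-2^limb_bits limbs,
    /// assuming all three inputs fit in 32 bits.
    fn u32_mul_add_decompose(
        m0: u64,
        m1: u64,
        a: u64,
        limb_bits: usize,
        num_limbs: usize,
    ) -> (u64, u64, Vec<u64>) {
        debug_assert!(m0 >> 32 == 0 && m1 >> 32 == 0 && a >> 32 == 0);
        let output = m0 * m1 + a; // at most (2^32 - 1)^2 + 2^32 - 1 < 2^64
        let (output_low, output_high) = (output & 0xFFFF_FFFF, output >> 32);
        // Little-endian limbs of the full 64-bit output; the gate range-checks
        // each limb and constrains the recombination.
        let limb_base = 1u64 << limb_bits;
        let mut acc = output;
        let limbs = (0..num_limbs)
            .map(|_| {
                let limb = acc % limb_base;
                acc /= limb_base;
                limb
            })
            .collect();
        (output_low, output_high, limbs)
    }

The two halves are what `eval_unfiltered` recombines as `output_high * 2^32 + output_low` and compares against `multiplicand_0 * multiplicand_1 + addend`.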

src/gates/assert_le.rs (new file)

@@ -0,0 +1,607 @@
use std::marker::PhantomData;
use crate::field::extension_field::target::ExtensionTarget;
use crate::field::extension_field::Extendable;
use crate::field::field_types::{Field, PrimeField, RichField};
use crate::gates::gate::Gate;
use crate::iop::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator};
use crate::iop::target::Target;
use crate::iop::wire::Wire;
use crate::iop::witness::{PartitionWitness, Witness};
use crate::plonk::circuit_builder::CircuitBuilder;
use crate::plonk::plonk_common::{reduce_with_powers, reduce_with_powers_ext_recursive};
use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase};
use crate::util::{bits_u64, ceil_div_usize};
// TODO: replace/merge this gate with `ComparisonGate`.
/// A gate for checking that one value is less than or equal to another.
#[derive(Clone, Debug)]
pub struct AssertLessThanGate<F: PrimeField + Extendable<D>, const D: usize> {
pub(crate) num_bits: usize,
pub(crate) num_chunks: usize,
_phantom: PhantomData<F>,
}
impl<F: RichField + Extendable<D>, const D: usize> AssertLessThanGate<F, D> {
pub fn new(num_bits: usize, num_chunks: usize) -> Self {
debug_assert!(num_bits < bits_u64(F::ORDER));
Self {
num_bits,
num_chunks,
_phantom: PhantomData,
}
}
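// e.g. num_bits = 40 and num_chunks = 5 give chunk_bits() = ceil(40 / 5) = 8,
// i.e. each input is decomposed into five 8-bit chunks.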
pub fn chunk_bits(&self) -> usize {
ceil_div_usize(self.num_bits, self.num_chunks)
}
pub fn wire_first_input(&self) -> usize {
0
}
pub fn wire_second_input(&self) -> usize {
1
}
pub fn wire_most_significant_diff(&self) -> usize {
2
}
pub fn wire_first_chunk_val(&self, chunk: usize) -> usize {
debug_assert!(chunk < self.num_chunks);
3 + chunk
}
pub fn wire_second_chunk_val(&self, chunk: usize) -> usize {
debug_assert!(chunk < self.num_chunks);
3 + self.num_chunks + chunk
}
pub fn wire_equality_dummy(&self, chunk: usize) -> usize {
debug_assert!(chunk < self.num_chunks);
3 + 2 * self.num_chunks + chunk
}
pub fn wire_chunks_equal(&self, chunk: usize) -> usize {
debug_assert!(chunk < self.num_chunks);
3 + 3 * self.num_chunks + chunk
}
pub fn wire_intermediate_value(&self, chunk: usize) -> usize {
debug_assert!(chunk < self.num_chunks);
3 + 4 * self.num_chunks + chunk
}
}
impl<F: RichField + Extendable<D>, const D: usize> Gate<F, D> for AssertLessThanGate<F, D> {
fn id(&self) -> String {
format!("{:?}<D={}>", self, D)
}
fn eval_unfiltered(&self, vars: EvaluationVars<F, D>) -> Vec<F::Extension> {
let mut constraints = Vec::with_capacity(self.num_constraints());
let first_input = vars.local_wires[self.wire_first_input()];
let second_input = vars.local_wires[self.wire_second_input()];
// Get chunks and assert that they match
let first_chunks: Vec<F::Extension> = (0..self.num_chunks)
.map(|i| vars.local_wires[self.wire_first_chunk_val(i)])
.collect();
let second_chunks: Vec<F::Extension> = (0..self.num_chunks)
.map(|i| vars.local_wires[self.wire_second_chunk_val(i)])
.collect();
let first_chunks_combined = reduce_with_powers(
&first_chunks,
F::Extension::from_canonical_usize(1 << self.chunk_bits()),
);
let second_chunks_combined = reduce_with_powers(
&second_chunks,
F::Extension::from_canonical_usize(1 << self.chunk_bits()),
);
constraints.push(first_chunks_combined - first_input);
constraints.push(second_chunks_combined - second_input);
let chunk_size = 1 << self.chunk_bits();
let mut most_significant_diff_so_far = F::Extension::ZERO;
for i in 0..self.num_chunks {
// Range-check the chunks to be less than `chunk_size`.
let first_product = (0..chunk_size)
.map(|x| first_chunks[i] - F::Extension::from_canonical_usize(x))
.product();
let second_product = (0..chunk_size)
.map(|x| second_chunks[i] - F::Extension::from_canonical_usize(x))
.product();
constraints.push(first_product);
constraints.push(second_product);
let difference = second_chunks[i] - first_chunks[i];
let equality_dummy = vars.local_wires[self.wire_equality_dummy(i)];
let chunks_equal = vars.local_wires[self.wire_chunks_equal(i)];
// Two constraints to assert that `chunks_equal` is valid.
constraints.push(difference * equality_dummy - (F::Extension::ONE - chunks_equal));
constraints.push(chunks_equal * difference);
// Update `most_significant_diff_so_far`.
let intermediate_value = vars.local_wires[self.wire_intermediate_value(i)];
constraints.push(intermediate_value - chunks_equal * most_significant_diff_so_far);
most_significant_diff_so_far =
intermediate_value + (F::Extension::ONE - chunks_equal) * difference;
}
let most_significant_diff = vars.local_wires[self.wire_most_significant_diff()];
constraints.push(most_significant_diff - most_significant_diff_so_far);
// Range check `most_significant_diff` to be less than `chunk_size`.
let product = (0..chunk_size)
.map(|x| most_significant_diff - F::Extension::from_canonical_usize(x))
.product();
constraints.push(product);
constraints
}
fn eval_unfiltered_base(&self, vars: EvaluationVarsBase<F>) -> Vec<F> {
let mut constraints = Vec::with_capacity(self.num_constraints());
let first_input = vars.local_wires[self.wire_first_input()];
let second_input = vars.local_wires[self.wire_second_input()];
// Get chunks and assert that they match
let first_chunks: Vec<F> = (0..self.num_chunks)
.map(|i| vars.local_wires[self.wire_first_chunk_val(i)])
.collect();
let second_chunks: Vec<F> = (0..self.num_chunks)
.map(|i| vars.local_wires[self.wire_second_chunk_val(i)])
.collect();
let first_chunks_combined = reduce_with_powers(
&first_chunks,
F::from_canonical_usize(1 << self.chunk_bits()),
);
let second_chunks_combined = reduce_with_powers(
&second_chunks,
F::from_canonical_usize(1 << self.chunk_bits()),
);
constraints.push(first_chunks_combined - first_input);
constraints.push(second_chunks_combined - second_input);
let chunk_size = 1 << self.chunk_bits();
let mut most_significant_diff_so_far = F::ZERO;
for i in 0..self.num_chunks {
// Range-check the chunks to be less than `chunk_size`.
let first_product = (0..chunk_size)
.map(|x| first_chunks[i] - F::from_canonical_usize(x))
.product();
let second_product = (0..chunk_size)
.map(|x| second_chunks[i] - F::from_canonical_usize(x))
.product();
constraints.push(first_product);
constraints.push(second_product);
let difference = second_chunks[i] - first_chunks[i];
let equality_dummy = vars.local_wires[self.wire_equality_dummy(i)];
let chunks_equal = vars.local_wires[self.wire_chunks_equal(i)];
// Two constraints to assert that `chunks_equal` is valid.
constraints.push(difference * equality_dummy - (F::ONE - chunks_equal));
constraints.push(chunks_equal * difference);
// Update `most_significant_diff_so_far`.
let intermediate_value = vars.local_wires[self.wire_intermediate_value(i)];
constraints.push(intermediate_value - chunks_equal * most_significant_diff_so_far);
most_significant_diff_so_far =
intermediate_value + (F::ONE - chunks_equal) * difference;
}
let most_significant_diff = vars.local_wires[self.wire_most_significant_diff()];
constraints.push(most_significant_diff - most_significant_diff_so_far);
// Range check `most_significant_diff` to be less than `chunk_size`.
let product = (0..chunk_size)
.map(|x| most_significant_diff - F::from_canonical_usize(x))
.product();
constraints.push(product);
constraints
}
fn eval_unfiltered_recursively(
&self,
builder: &mut CircuitBuilder<F, D>,
vars: EvaluationTargets<D>,
) -> Vec<ExtensionTarget<D>> {
let mut constraints = Vec::with_capacity(self.num_constraints());
let first_input = vars.local_wires[self.wire_first_input()];
let second_input = vars.local_wires[self.wire_second_input()];
// Get chunks and assert that they match
let first_chunks: Vec<ExtensionTarget<D>> = (0..self.num_chunks)
.map(|i| vars.local_wires[self.wire_first_chunk_val(i)])
.collect();
let second_chunks: Vec<ExtensionTarget<D>> = (0..self.num_chunks)
.map(|i| vars.local_wires[self.wire_second_chunk_val(i)])
.collect();
let chunk_base = builder.constant(F::from_canonical_usize(1 << self.chunk_bits()));
let first_chunks_combined =
reduce_with_powers_ext_recursive(builder, &first_chunks, chunk_base);
let second_chunks_combined =
reduce_with_powers_ext_recursive(builder, &second_chunks, chunk_base);
constraints.push(builder.sub_extension(first_chunks_combined, first_input));
constraints.push(builder.sub_extension(second_chunks_combined, second_input));
let chunk_size = 1 << self.chunk_bits();
let mut most_significant_diff_so_far = builder.zero_extension();
let one = builder.one_extension();
// Find the chosen chunk.
for i in 0..self.num_chunks {
// Range-check the chunks to be less than `chunk_size`.
let mut first_product = one;
let mut second_product = one;
for x in 0..chunk_size {
let x_f = builder.constant_extension(F::Extension::from_canonical_usize(x));
let first_diff = builder.sub_extension(first_chunks[i], x_f);
let second_diff = builder.sub_extension(second_chunks[i], x_f);
first_product = builder.mul_extension(first_product, first_diff);
second_product = builder.mul_extension(second_product, second_diff);
}
constraints.push(first_product);
constraints.push(second_product);
let difference = builder.sub_extension(second_chunks[i], first_chunks[i]);
let equality_dummy = vars.local_wires[self.wire_equality_dummy(i)];
let chunks_equal = vars.local_wires[self.wire_chunks_equal(i)];
// Two constraints to assert that `chunks_equal` is valid.
let diff_times_equal = builder.mul_extension(difference, equality_dummy);
let not_equal = builder.sub_extension(one, chunks_equal);
constraints.push(builder.sub_extension(diff_times_equal, not_equal));
constraints.push(builder.mul_extension(chunks_equal, difference));
// Update `most_significant_diff_so_far`.
let intermediate_value = vars.local_wires[self.wire_intermediate_value(i)];
let old_diff = builder.mul_extension(chunks_equal, most_significant_diff_so_far);
constraints.push(builder.sub_extension(intermediate_value, old_diff));
let not_equal = builder.sub_extension(one, chunks_equal);
let new_diff = builder.mul_extension(not_equal, difference);
most_significant_diff_so_far = builder.add_extension(intermediate_value, new_diff);
}
let most_significant_diff = vars.local_wires[self.wire_most_significant_diff()];
constraints
.push(builder.sub_extension(most_significant_diff, most_significant_diff_so_far));
// Range check `most_significant_diff` to be less than `chunk_size`.
let mut product = builder.one_extension();
for x in 0..chunk_size {
let x_f = builder.constant_extension(F::Extension::from_canonical_usize(x));
let diff = builder.sub_extension(most_significant_diff, x_f);
product = builder.mul_extension(product, diff);
}
constraints.push(product);
constraints
}
fn generators(
&self,
gate_index: usize,
_local_constants: &[F],
) -> Vec<Box<dyn WitnessGenerator<F>>> {
let gen = AssertLessThanGenerator::<F, D> {
gate_index,
gate: self.clone(),
};
vec![Box::new(gen.adapter())]
}
fn num_wires(&self) -> usize {
self.wire_intermediate_value(self.num_chunks - 1) + 1
}
fn num_constants(&self) -> usize {
0
}
fn degree(&self) -> usize {
1 << self.chunk_bits()
}
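// The count below breaks down as: 2 chunk-recomposition checks, 1
// `most_significant_diff` consistency check and 1 final range check, plus 5
// per chunk (two chunk range checks, two `chunks_equal` validity checks, and
// one intermediate-value check).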
fn num_constraints(&self) -> usize {
4 + 5 * self.num_chunks
}
}
#[derive(Debug)]
struct AssertLessThanGenerator<F: RichField + Extendable<D>, const D: usize> {
gate_index: usize,
gate: AssertLessThanGate<F, D>,
}
impl<F: RichField + Extendable<D>, const D: usize> SimpleGenerator<F>
for AssertLessThanGenerator<F, D>
{
fn dependencies(&self) -> Vec<Target> {
let local_target = |input| Target::wire(self.gate_index, input);
vec![
local_target(self.gate.wire_first_input()),
local_target(self.gate.wire_second_input()),
]
}
fn run_once(&self, witness: &PartitionWitness<F>, out_buffer: &mut GeneratedValues<F>) {
let local_wire = |input| Wire {
gate: self.gate_index,
input,
};
let get_local_wire = |input| witness.get_wire(local_wire(input));
let first_input = get_local_wire(self.gate.wire_first_input());
let second_input = get_local_wire(self.gate.wire_second_input());
let first_input_u64 = first_input.to_canonical_u64();
let second_input_u64 = second_input.to_canonical_u64();
debug_assert!(first_input_u64 < second_input_u64);
let chunk_size = 1 << self.gate.chunk_bits();
let first_input_chunks: Vec<F> = (0..self.gate.num_chunks)
.scan(first_input_u64, |acc, _| {
let tmp = *acc % chunk_size;
*acc /= chunk_size;
Some(F::from_canonical_u64(tmp))
})
.collect();
let second_input_chunks: Vec<F> = (0..self.gate.num_chunks)
.scan(second_input_u64, |acc, _| {
let tmp = *acc % chunk_size;
*acc /= chunk_size;
Some(F::from_canonical_u64(tmp))
})
.collect();
let chunks_equal: Vec<F> = (0..self.gate.num_chunks)
.map(|i| F::from_bool(first_input_chunks[i] == second_input_chunks[i]))
.collect();
let equality_dummies: Vec<F> = first_input_chunks
.iter()
.zip(second_input_chunks.iter())
.map(|(&f, &s)| if f == s { F::ONE } else { F::ONE / (s - f) })
.collect();
let mut most_significant_diff_so_far = F::ZERO;
let mut intermediate_values = Vec::new();
for i in 0..self.gate.num_chunks {
if first_input_chunks[i] != second_input_chunks[i] {
most_significant_diff_so_far = second_input_chunks[i] - first_input_chunks[i];
intermediate_values.push(F::ZERO);
} else {
intermediate_values.push(most_significant_diff_so_far);
}
}
let most_significant_diff = most_significant_diff_so_far;
out_buffer.set_wire(
local_wire(self.gate.wire_most_significant_diff()),
most_significant_diff,
);
for i in 0..self.gate.num_chunks {
out_buffer.set_wire(
local_wire(self.gate.wire_first_chunk_val(i)),
first_input_chunks[i],
);
out_buffer.set_wire(
local_wire(self.gate.wire_second_chunk_val(i)),
second_input_chunks[i],
);
out_buffer.set_wire(
local_wire(self.gate.wire_equality_dummy(i)),
equality_dummies[i],
);
out_buffer.set_wire(local_wire(self.gate.wire_chunks_equal(i)), chunks_equal[i]);
out_buffer.set_wire(
local_wire(self.gate.wire_intermediate_value(i)),
intermediate_values[i],
);
}
}
}
#[cfg(test)]
mod tests {
use std::marker::PhantomData;
use anyhow::Result;
use rand::Rng;
use crate::field::extension_field::quartic::QuarticExtension;
use crate::field::field_types::{Field, PrimeField};
use crate::field::goldilocks_field::GoldilocksField;
use crate::gates::assert_le::AssertLessThanGate;
use crate::gates::gate::Gate;
use crate::gates::gate_testing::{test_eval_fns, test_low_degree};
use crate::hash::hash_types::HashOut;
use crate::plonk::vars::EvaluationVars;
#[test]
fn wire_indices() {
type AG = AssertLessThanGate<GoldilocksField, 4>;
let num_bits = 40;
let num_chunks = 5;
let gate = AG {
num_bits,
num_chunks,
_phantom: PhantomData,
};
assert_eq!(gate.wire_first_input(), 0);
assert_eq!(gate.wire_second_input(), 1);
assert_eq!(gate.wire_most_significant_diff(), 2);
assert_eq!(gate.wire_first_chunk_val(0), 3);
assert_eq!(gate.wire_first_chunk_val(4), 7);
assert_eq!(gate.wire_second_chunk_val(0), 8);
assert_eq!(gate.wire_second_chunk_val(4), 12);
assert_eq!(gate.wire_equality_dummy(0), 13);
assert_eq!(gate.wire_equality_dummy(4), 17);
assert_eq!(gate.wire_chunks_equal(0), 18);
assert_eq!(gate.wire_chunks_equal(4), 22);
assert_eq!(gate.wire_intermediate_value(0), 23);
assert_eq!(gate.wire_intermediate_value(4), 27);
}
#[test]
fn low_degree() {
let num_bits = 20;
let num_chunks = 4;
test_low_degree::<GoldilocksField, _, 4>(AssertLessThanGate::<_, 4>::new(
num_bits, num_chunks,
))
}
#[test]
fn eval_fns() -> Result<()> {
let num_bits = 20;
let num_chunks = 4;
test_eval_fns::<GoldilocksField, _, 4>(AssertLessThanGate::<_, 4>::new(
num_bits, num_chunks,
))
}
#[test]
fn test_gate_constraint() {
type F = GoldilocksField;
type FF = QuarticExtension<GoldilocksField>;
const D: usize = 4;
let num_bits = 40;
let num_chunks = 5;
let chunk_bits = num_bits / num_chunks;
// Returns the local wires for an AssertLessThanGate given the two inputs.
let get_wires = |first_input: F, second_input: F| -> Vec<FF> {
let mut v = Vec::new();
let first_input_u64 = first_input.to_canonical_u64();
let second_input_u64 = second_input.to_canonical_u64();
let chunk_size = 1 << chunk_bits;
let mut first_input_chunks: Vec<F> = (0..num_chunks)
.scan(first_input_u64, |acc, _| {
let tmp = *acc % chunk_size;
*acc /= chunk_size;
Some(F::from_canonical_u64(tmp))
})
.collect();
let mut second_input_chunks: Vec<F> = (0..num_chunks)
.scan(second_input_u64, |acc, _| {
let tmp = *acc % chunk_size;
*acc /= chunk_size;
Some(F::from_canonical_u64(tmp))
})
.collect();
let mut chunks_equal: Vec<F> = (0..num_chunks)
.map(|i| F::from_bool(first_input_chunks[i] == second_input_chunks[i]))
.collect();
let mut equality_dummies: Vec<F> = first_input_chunks
.iter()
.zip(second_input_chunks.iter())
.map(|(&f, &s)| if f == s { F::ONE } else { F::ONE / (s - f) })
.collect();
let mut most_significant_diff_so_far = F::ZERO;
let mut intermediate_values = Vec::new();
for i in 0..num_chunks {
if first_input_chunks[i] != second_input_chunks[i] {
most_significant_diff_so_far = second_input_chunks[i] - first_input_chunks[i];
intermediate_values.push(F::ZERO);
} else {
intermediate_values.push(most_significant_diff_so_far);
}
}
let most_significant_diff = most_significant_diff_so_far;
v.push(first_input);
v.push(second_input);
v.push(most_significant_diff);
v.append(&mut first_input_chunks);
v.append(&mut second_input_chunks);
v.append(&mut equality_dummies);
v.append(&mut chunks_equal);
v.append(&mut intermediate_values);
v.iter().map(|&x| x.into()).collect::<Vec<_>>()
};
let mut rng = rand::thread_rng();
let max: u64 = 1 << (num_bits - 1);
let first_input_u64 = rng.gen_range(0..max);
let second_input_u64 = {
let mut val = rng.gen_range(0..max);
while val < first_input_u64 {
val = rng.gen_range(0..max);
}
val
};
let first_input = F::from_canonical_u64(first_input_u64);
let second_input = F::from_canonical_u64(second_input_u64);
let less_than_gate = AssertLessThanGate::<F, D> {
num_bits,
num_chunks,
_phantom: PhantomData,
};
let less_than_vars = EvaluationVars {
local_constants: &[],
local_wires: &get_wires(first_input, second_input),
public_inputs_hash: &HashOut::rand(),
};
assert!(
less_than_gate
.eval_unfiltered(less_than_vars)
.iter()
.all(|x| x.is_zero()),
"Gate constraints are not satisfied."
);
let equal_gate = AssertLessThanGate::<F, D> {
num_bits,
num_chunks,
_phantom: PhantomData,
};
let equal_vars = EvaluationVars {
local_constants: &[],
local_wires: &get_wires(first_input, first_input),
public_inputs_hash: &HashOut::rand(),
};
assert!(
equal_gate
.eval_unfiltered(equal_vars)
.iter()
.all(|x| x.is_zero()),
"Gate constraints are not satisfied."
);
}
}
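Outside the circuit, the chunk bookkeeping done by `AssertLessThanGenerator::run_once` amounts to the following standalone sketch (illustrative names, plain integers standing in for field elements):

    /// Split both inputs into `num_chunks` little-endian chunks of `chunk_bits`
    /// bits and return the most significant chunk difference `second - first`,
    /// mirroring the generator's `most_significant_diff_so_far` loop.
    fn most_significant_chunk_diff(
        first: u64,
        second: u64,
        chunk_bits: usize,
        num_chunks: usize,
    ) -> i64 {
        let chunk_size = 1u64 << chunk_bits;
        let chunks = |mut x: u64| -> Vec<u64> {
            (0..num_chunks)
                .map(|_| {
                    let c = x % chunk_size;
                    x /= chunk_size;
                    c
                })
                .collect()
        };
        let (a, b) = (chunks(first), chunks(second));
        let mut msd: i64 = 0;
        for i in 0..num_chunks {
            // A later (more significant) differing chunk overwrites earlier ones.
            if a[i] != b[i] {
                msd = b[i] as i64 - a[i] as i64;
            }
        }
        msd // zero or positive exactly when first <= second
    }

In the gate itself this runs over field elements, with the `chunks_equal` and `equality_dummy` wires proving each `a[i] != b[i]` test in-circuit.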

src/gates/base_sum.rs

@@ -24,8 +24,7 @@ impl<const B: usize> BaseSumGate<B> {
     }
 
     pub fn new_from_config<F: PrimeField>(config: &CircuitConfig) -> Self {
-        let num_limbs = ((F::ORDER as f64).log(B as f64).floor() as usize)
-            .min(config.num_routed_wires - Self::START_LIMBS);
+        let num_limbs = F::BITS.min(config.num_routed_wires - Self::START_LIMBS);
         Self::new(num_limbs)
     }

src/gates/comparison.rs

@@ -43,33 +43,42 @@ impl<F: Extendable<D>, const D: usize> ComparisonGate<F, D> {
         1
     }
 
-    pub fn wire_most_significant_diff(&self) -> usize {
+    pub fn wire_result_bool(&self) -> usize {
         2
     }
 
+    pub fn wire_most_significant_diff(&self) -> usize {
+        3
+    }
+
     pub fn wire_first_chunk_val(&self, chunk: usize) -> usize {
         debug_assert!(chunk < self.num_chunks);
-        3 + chunk
+        4 + chunk
     }
 
     pub fn wire_second_chunk_val(&self, chunk: usize) -> usize {
         debug_assert!(chunk < self.num_chunks);
-        3 + self.num_chunks + chunk
+        4 + self.num_chunks + chunk
     }
 
     pub fn wire_equality_dummy(&self, chunk: usize) -> usize {
         debug_assert!(chunk < self.num_chunks);
-        3 + 2 * self.num_chunks + chunk
+        4 + 2 * self.num_chunks + chunk
     }
 
     pub fn wire_chunks_equal(&self, chunk: usize) -> usize {
         debug_assert!(chunk < self.num_chunks);
-        3 + 3 * self.num_chunks + chunk
+        4 + 3 * self.num_chunks + chunk
     }
 
     pub fn wire_intermediate_value(&self, chunk: usize) -> usize {
         debug_assert!(chunk < self.num_chunks);
-        3 + 4 * self.num_chunks + chunk
+        4 + 4 * self.num_chunks + chunk
     }
+
+    /// The `bit_index`th bit of 2^n + most_significant_diff.
+    pub fn wire_most_significant_diff_bit(&self, bit_index: usize) -> usize {
+        4 + 5 * self.num_chunks + bit_index
+    }
 }
@@ -110,10 +119,10 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for ComparisonGate<F, D> {
         for i in 0..self.num_chunks {
             // Range-check the chunks to be less than `chunk_size`.
-            let first_product = (0..chunk_size)
+            let first_product: F::Extension = (0..chunk_size)
                 .map(|x| first_chunks[i] - F::Extension::from_canonical_usize(x))
                 .product();
-            let second_product = (0..chunk_size)
+            let second_product: F::Extension = (0..chunk_size)
                 .map(|x| second_chunks[i] - F::Extension::from_canonical_usize(x))
                 .product();
             constraints.push(first_product);
@@ -137,11 +146,22 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for ComparisonGate<F, D> {
         let most_significant_diff = vars.local_wires[self.wire_most_significant_diff()];
         constraints.push(most_significant_diff - most_significant_diff_so_far);
 
-        // Range check `most_significant_diff` to be less than `chunk_size`.
-        let product = (0..chunk_size)
-            .map(|x| most_significant_diff - F::Extension::from_canonical_usize(x))
-            .product();
-        constraints.push(product);
+        let most_significant_diff_bits: Vec<F::Extension> = (0..self.chunk_bits() + 1)
+            .map(|i| vars.local_wires[self.wire_most_significant_diff_bit(i)])
+            .collect();
+
+        // Range-check the bits.
+        for &bit in &most_significant_diff_bits {
+            constraints.push(bit * (F::Extension::ONE - bit));
+        }
+
+        let bits_combined = reduce_with_powers(&most_significant_diff_bits, F::Extension::TWO);
+        let two_n = F::Extension::from_canonical_u64(1 << self.chunk_bits());
+        constraints.push((two_n + most_significant_diff) - bits_combined);
+
+        // Iff first <= second, the top (n + 1st) bit of (2^n + most_significant_diff) will be 1.
+        let result_bool = vars.local_wires[self.wire_result_bool()];
+        constraints.push(result_bool - most_significant_diff_bits[self.chunk_bits()]);
 
         constraints
     }
@@ -178,10 +198,10 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for ComparisonGate<F, D> {
         for i in 0..self.num_chunks {
             // Range-check the chunks to be less than `chunk_size`.
-            let first_product = (0..chunk_size)
+            let first_product: F = (0..chunk_size)
                 .map(|x| first_chunks[i] - F::from_canonical_usize(x))
                 .product();
-            let second_product = (0..chunk_size)
+            let second_product: F = (0..chunk_size)
                 .map(|x| second_chunks[i] - F::from_canonical_usize(x))
                 .product();
             constraints.push(first_product);
@@ -205,11 +225,22 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for ComparisonGate<F, D> {
         let most_significant_diff = vars.local_wires[self.wire_most_significant_diff()];
         constraints.push(most_significant_diff - most_significant_diff_so_far);
 
-        // Range check `most_significant_diff` to be less than `chunk_size`.
-        let product = (0..chunk_size)
-            .map(|x| most_significant_diff - F::from_canonical_usize(x))
-            .product();
-        constraints.push(product);
+        let most_significant_diff_bits: Vec<F> = (0..self.chunk_bits() + 1)
+            .map(|i| vars.local_wires[self.wire_most_significant_diff_bit(i)])
+            .collect();
+
+        // Range-check the bits.
+        for &bit in &most_significant_diff_bits {
+            constraints.push(bit * (F::ONE - bit));
+        }
+
+        let bits_combined = reduce_with_powers(&most_significant_diff_bits, F::TWO);
+        let two_n = F::from_canonical_u64(1 << self.chunk_bits());
+        constraints.push((two_n + most_significant_diff) - bits_combined);
+
+        // Iff first <= second, the top (n + 1st) bit of (2^n + most_significant_diff) will be 1.
+        let result_bool = vars.local_wires[self.wire_result_bool()];
+        constraints.push(result_bool - most_significant_diff_bits[self.chunk_bits()]);
 
         constraints
     }
@@ -285,14 +316,29 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for ComparisonGate<F, D> {
         constraints
             .push(builder.sub_extension(most_significant_diff, most_significant_diff_so_far));
 
-        // Range check `most_significant_diff` to be less than `chunk_size`.
-        let mut product = builder.one_extension();
-        for x in 0..chunk_size {
-            let x_f = builder.constant_extension(F::Extension::from_canonical_usize(x));
-            let diff = builder.sub_extension(most_significant_diff, x_f);
-            product = builder.mul_extension(product, diff);
-        }
-        constraints.push(product);
+        let most_significant_diff_bits: Vec<ExtensionTarget<D>> = (0..self.chunk_bits() + 1)
+            .map(|i| vars.local_wires[self.wire_most_significant_diff_bit(i)])
+            .collect();
+
+        // Range-check the bits.
+        for &this_bit in &most_significant_diff_bits {
+            let inverse = builder.sub_extension(one, this_bit);
+            constraints.push(builder.mul_extension(this_bit, inverse));
+        }
+
+        let two = builder.two();
+        let bits_combined =
+            reduce_with_powers_ext_recursive(builder, &most_significant_diff_bits, two);
+        let two_n =
+            builder.constant_extension(F::Extension::from_canonical_u64(1 << self.chunk_bits()));
+        let sum = builder.add_extension(two_n, most_significant_diff);
+        constraints.push(builder.sub_extension(sum, bits_combined));
+
+        // Iff first <= second, the top (n + 1st) bit of (2^n + most_significant_diff) will be 1.
+        let result_bool = vars.local_wires[self.wire_result_bool()];
+        constraints.push(
+            builder.sub_extension(result_bool, most_significant_diff_bits[self.chunk_bits()]),
+        );
 
         constraints
     }
@@ -310,7 +356,7 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for ComparisonGate<F, D> {
     }
 
     fn num_wires(&self) -> usize {
-        self.wire_intermediate_value(self.num_chunks - 1) + 1
+        4 + 5 * self.num_chunks + (self.chunk_bits() + 1)
     }
 
     fn num_constants(&self) -> usize {
@@ -322,7 +368,7 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for ComparisonGate<F, D> {
     }
 
     fn num_constraints(&self) -> usize {
-        4 + 5 * self.num_chunks
+        6 + 5 * self.num_chunks + self.chunk_bits()
     }
 }
@@ -336,10 +382,10 @@ impl<F: Extendable<D>, const D: usize> SimpleGenerator<F> for ComparisonGenerato
     fn dependencies(&self) -> Vec<Target> {
         let local_target = |input| Target::wire(self.gate_index, input);
-        let mut deps = Vec::new();
-        deps.push(local_target(self.gate.wire_first_input()));
-        deps.push(local_target(self.gate.wire_second_input()));
-        deps
+        vec![
+            local_target(self.gate.wire_first_input()),
+            local_target(self.gate.wire_second_input()),
+        ]
     }
 
     fn run_once(&self, witness: &PartitionWitness<F>, out_buffer: &mut GeneratedValues<F>) {
@@ -356,7 +402,7 @@ impl<F: Extendable<D>, const D: usize> SimpleGenerator<F> for ComparisonGenerato
         let first_input_u64 = first_input.to_canonical_u64();
         let second_input_u64 = second_input.to_canonical_u64();
 
-        debug_assert!(first_input_u64 < second_input_u64);
+        let result = F::from_canonical_usize((first_input_u64 <= second_input_u64) as usize);
 
         let chunk_size = 1 << self.gate.chunk_bits();
         let first_input_chunks: Vec<F> = (0..self.gate.num_chunks)
@@ -395,6 +441,22 @@ impl<F: Extendable<D>, const D: usize> SimpleGenerator<F> for ComparisonGenerato
         }
         let most_significant_diff = most_significant_diff_so_far;
 
+        let two_n = F::from_canonical_usize(1 << self.gate.chunk_bits());
+        let two_n_plus_msd = (two_n + most_significant_diff).to_canonical_u64();
+        let msd_bits_u64: Vec<u64> = (0..self.gate.chunk_bits() + 1)
+            .scan(two_n_plus_msd, |acc, _| {
+                let tmp = *acc % 2;
+                *acc /= 2;
+                Some(tmp)
+            })
+            .collect();
+        let msd_bits: Vec<F> = msd_bits_u64
+            .iter()
+            .map(|x| F::from_canonical_u64(*x))
+            .collect();
+
+        out_buffer.set_wire(local_wire(self.gate.wire_result_bool()), result);
         out_buffer.set_wire(
             local_wire(self.gate.wire_most_significant_diff()),
             most_significant_diff,
@@ -418,6 +480,12 @@ impl<F: Extendable<D>, const D: usize> SimpleGenerator<F> for ComparisonGenerato
                 intermediate_values[i],
             );
         }
+        for i in 0..self.gate.chunk_bits() + 1 {
+            out_buffer.set_wire(
+                local_wire(self.gate.wire_most_significant_diff_bit(i)),
+                msd_bits[i],
+            );
+        }
     }
 }
@@ -451,17 +519,20 @@ mod tests {
         assert_eq!(gate.wire_first_input(), 0);
         assert_eq!(gate.wire_second_input(), 1);
-        assert_eq!(gate.wire_most_significant_diff(), 2);
-        assert_eq!(gate.wire_first_chunk_val(0), 3);
-        assert_eq!(gate.wire_first_chunk_val(4), 7);
-        assert_eq!(gate.wire_second_chunk_val(0), 8);
-        assert_eq!(gate.wire_second_chunk_val(4), 12);
-        assert_eq!(gate.wire_equality_dummy(0), 13);
-        assert_eq!(gate.wire_equality_dummy(4), 17);
-        assert_eq!(gate.wire_chunks_equal(0), 18);
-        assert_eq!(gate.wire_chunks_equal(4), 22);
-        assert_eq!(gate.wire_intermediate_value(0), 23);
-        assert_eq!(gate.wire_intermediate_value(4), 27);
+        assert_eq!(gate.wire_result_bool(), 2);
+        assert_eq!(gate.wire_most_significant_diff(), 3);
+        assert_eq!(gate.wire_first_chunk_val(0), 4);
+        assert_eq!(gate.wire_first_chunk_val(4), 8);
+        assert_eq!(gate.wire_second_chunk_val(0), 9);
+        assert_eq!(gate.wire_second_chunk_val(4), 13);
+        assert_eq!(gate.wire_equality_dummy(0), 14);
+        assert_eq!(gate.wire_equality_dummy(4), 18);
+        assert_eq!(gate.wire_chunks_equal(0), 19);
+        assert_eq!(gate.wire_chunks_equal(4), 23);
+        assert_eq!(gate.wire_intermediate_value(0), 24);
+        assert_eq!(gate.wire_intermediate_value(4), 28);
+        assert_eq!(gate.wire_most_significant_diff_bit(0), 29);
+        assert_eq!(gate.wire_most_significant_diff_bit(8), 37);
     }
 
     #[test]
@@ -501,6 +572,8 @@ mod tests {
             let first_input_u64 = first_input.to_canonical_u64();
             let second_input_u64 = second_input.to_canonical_u64();
 
+            let result_bool = F::from_bool(first_input_u64 <= second_input_u64);
+
             let chunk_size = 1 << chunk_bits;
             let mut first_input_chunks: Vec<F> = (0..num_chunks)
                 .scan(first_input_u64, |acc, _| {
@@ -538,20 +611,32 @@ mod tests {
             }
             let most_significant_diff = most_significant_diff_so_far;
 
+            let two_n_plus_msd =
+                (1 << chunk_bits) as u64 + most_significant_diff.to_canonical_u64();
+            let mut msd_bits: Vec<F> = (0..chunk_bits + 1)
+                .scan(two_n_plus_msd, |acc, _| {
+                    let tmp = *acc % 2;
+                    *acc /= 2;
+                    Some(F::from_canonical_u64(tmp))
+                })
+                .collect();
+
             v.push(first_input);
             v.push(second_input);
+            v.push(result_bool);
             v.push(most_significant_diff);
             v.append(&mut first_input_chunks);
             v.append(&mut second_input_chunks);
             v.append(&mut equality_dummies);
             v.append(&mut chunks_equal);
             v.append(&mut intermediate_values);
+            v.append(&mut msd_bits);
 
             v.iter().map(|&x| x.into()).collect::<Vec<_>>()
         };
 
         let mut rng = rand::thread_rng();
-        let max: u64 = 1 << num_bits - 1;
+        let max: u64 = 1 << (num_bits - 1);
         let first_input_u64 = rng.gen_range(0..max);
         let second_input_u64 = {
             let mut val = rng.gen_range(0..max);
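A quick numeric check of the bit trick used above (standalone illustration; `n` plays the role of `chunk_bits()`): the most significant chunk difference `d` satisfies |d| < 2^n, so 2^n + d lies strictly between 0 and 2^(n+1), fits in n + 1 bits, and its top bit is 1 exactly when d >= 0, i.e. when first <= second. For negative `d` the field element 2^n + d reduces to the same small canonical value 2^n - |d|, so signed i64 arithmetic mirrors it here.

    fn msd_bit_trick_demo() {
        let n = 8;
        // first chunk 0x12, second chunk 0x34: d >= 0, so bit n of 2^n + d is 1.
        let d: i64 = 0x34 - 0x12;
        assert_eq!((((1i64 << n) + d) >> n) & 1, 1);
        // first chunk 0x34, second chunk 0x12: d < 0, 2^n + d = 0xDE, bit n is 0.
        let d: i64 = 0x12 - 0x34;
        assert_eq!((((1i64 << n) + d) >> n) & 1, 0);
    }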

src/gates/exponentiation.rs

@@ -337,9 +337,8 @@ mod tests {
                 .map(|b| F::from_canonical_u64(*b))
                 .collect();
 
-            let mut v = Vec::new();
-            v.push(base);
-            v.extend(power_bits_f.clone());
+            let mut v = vec![base];
+            v.extend(power_bits_f);
 
             let mut intermediate_values = Vec::new();
             let mut current_intermediate_value = F::ONE;

src/gates/gate_testing.rs

@@ -10,7 +10,7 @@ use crate::plonk::circuit_data::CircuitConfig;
 use crate::plonk::config::GenericConfig;
 use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase};
 use crate::plonk::verifier::verify;
-use crate::polynomial::polynomial::{PolynomialCoeffs, PolynomialValues};
+use crate::polynomial::{PolynomialCoeffs, PolynomialValues};
 use crate::util::{log2_ceil, transpose};
 
 const WITNESS_SIZE: usize = 1 << 5;

src/gates/gate_tree.rs

@@ -1,4 +1,4 @@
-use log::info;
+use log::debug;
 
 use crate::field::extension_field::Extendable;
 use crate::field::field_types::RichField;
@@ -86,7 +86,7 @@ impl<F: RichField + Extendable<D>, const D: usize> Tree<GateRef<F, D>> {
                 }
             }
         }
-        info!(
+        debug!(
            "Found tree with max degree {} and {} constants wires in {:.4}s.",
            best_degree,
            best_num_constants,
@@ -221,12 +221,17 @@ impl<F: RichField + Extendable<D>, const D: usize> Tree<GateRef<F, D>> {
 #[cfg(test)]
 mod tests {
+    use log::info;
+
     use super::*;
-    use crate::gates::arithmetic_extension::ArithmeticExtensionGate;
+    use crate::field::goldilocks_field::GoldilocksField;
+    use crate::gadgets::interpolation::InterpolationGate;
+    use crate::gates::arithmetic::ArithmeticExtensionGate;
     use crate::gates::base_sum::BaseSumGate;
     use crate::gates::constant::ConstantGate;
     use crate::gates::gmimc::GMiMCGate;
-    use crate::gates::interpolation::InterpolationGate;
+    use crate::gates::interpolation::HighDegreeInterpolationGate;
     use crate::gates::noop::NoopGate;
     use crate::plonk::config::{GenericConfig, PoseidonGoldilocksConfig};
@@ -243,7 +248,7 @@ mod tests {
             GateRef::new(ArithmeticExtensionGate { num_ops: 4 }),
             GateRef::new(BaseSumGate::<4>::new(4)),
             GateRef::new(GMiMCGate::<F, D, 12>::new()),
-            GateRef::new(InterpolationGate::new(2)),
+            GateRef::new(HighDegreeInterpolationGate::new(2)),
         ];
 
         let (tree, _, _) = Tree::from_gates(gates.clone());

src/gates/gmimc.rs

@@ -318,8 +318,6 @@ impl<F: Extendable<D> + GMiMC<WIDTH>, const D: usize, const WIDTH: usize> Simple
 #[cfg(test)]
 mod tests {
-    use std::convert::TryInto;
-
     use anyhow::Result;
 
     use crate::field::field_types::Field;

src/gates/insertion.rs

@@ -1,4 +1,3 @@
-use std::convert::TryInto;
 use std::marker::PhantomData;
 use std::ops::Range;
@@ -252,8 +251,7 @@ impl<F: RichField + Extendable<D>, const D: usize> SimpleGenerator<F> for Insert
         let local_targets = |inputs: Range<usize>| inputs.map(local_target);
 
-        let mut deps = Vec::new();
-        deps.push(local_target(self.gate.wires_insertion_index()));
+        let mut deps = vec![local_target(self.gate.wires_insertion_index())];
         deps.extend(local_targets(self.gate.wires_element_to_insert()));
         for i in 0..self.gate.vec_size {
             deps.extend(local_targets(self.gate.wires_original_list_item(i)));
@@ -292,7 +290,7 @@ impl<F: RichField + Extendable<D>, const D: usize> SimpleGenerator<F> for Insert
             vec_size
         );
 
-        let mut new_vec = orig_vec.clone();
+        let mut new_vec = orig_vec;
         new_vec.insert(insertion_index, to_insert);
 
         let mut equality_dummy_vals = Vec::new();
@@ -377,14 +375,13 @@ mod tests {
     fn get_wires(orig_vec: Vec<FF>, insertion_index: usize, element_to_insert: FF) -> Vec<FF> {
         let vec_size = orig_vec.len();
 
-        let mut v = Vec::new();
-        v.push(F::from_canonical_usize(insertion_index));
+        let mut v = vec![F::from_canonical_usize(insertion_index)];
         v.extend(element_to_insert.0);
         for j in 0..vec_size {
             v.extend(orig_vec[j].0);
         }
 
-        let mut new_vec = orig_vec.clone();
+        let mut new_vec = orig_vec;
         new_vec.insert(insertion_index, element_to_insert);
 
         let mut equality_dummy_vals = Vec::new();
         for i in 0..=vec_size {

src/gates/interpolation.rs

@@ -1,4 +1,3 @@
-use std::convert::TryInto;
 use std::marker::PhantomData;
 use std::ops::Range;
@@ -6,6 +5,7 @@ use crate::field::extension_field::algebra::PolynomialCoeffsAlgebra;
 use crate::field::extension_field::target::ExtensionTarget;
 use crate::field::extension_field::{Extendable, FieldExtension};
 use crate::field::interpolation::interpolant;
+use crate::gadgets::interpolation::InterpolationGate;
 use crate::gadgets::polynomial::PolynomialCoeffsExtAlgebraTarget;
 use crate::gates::gate::Gate;
 use crate::iop::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator};
@@ -14,19 +14,20 @@ use crate::iop::wire::Wire;
 use crate::iop::witness::{PartitionWitness, Witness};
 use crate::plonk::circuit_builder::CircuitBuilder;
 use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase};
-use crate::polynomial::polynomial::PolynomialCoeffs;
+use crate::polynomial::PolynomialCoeffs;
 
-/// Interpolates a polynomial, whose points are a (base field) coset of the multiplicative subgroup
-/// with the given size, and whose values are extension field elements, given by input wires.
-/// Outputs the evaluation of the interpolant at a given (extension field) evaluation point.
-#[derive(Clone, Debug)]
-pub(crate) struct InterpolationGate<F: Extendable<D>, const D: usize> {
+/// Interpolation gate with constraints of degree at most `1<<subgroup_bits`.
+/// `eval_unfiltered_recursively` uses fewer gates than `LowDegreeInterpolationGate`.
+#[derive(Copy, Clone, Debug)]
+pub(crate) struct HighDegreeInterpolationGate<F: RichField + Extendable<D>, const D: usize> {
     pub subgroup_bits: usize,
     _phantom: PhantomData<F>,
 }
 
-impl<F: Extendable<D>, const D: usize> InterpolationGate<F, D> {
-    pub fn new(subgroup_bits: usize) -> Self {
+impl<F: Extendable<D>, const D: usize> InterpolationGate<F, D>
+    for HighDegreeInterpolationGate<F, D>
+{
+    fn new(subgroup_bits: usize) -> Self {
         Self {
             subgroup_bits,
             _phantom: PhantomData,
@@ -36,60 +37,9 @@ impl<F: Extendable<D>, const D: usize> InterpolationGate<F, D> {
     fn num_points(&self) -> usize {
         1 << self.subgroup_bits
     }
-
-    /// Wire index of the coset shift.
-    pub fn wire_shift(&self) -> usize {
-        0
-    }
-
-    fn start_values(&self) -> usize {
-        1
-    }
-
-    /// Wire indices of the `i`th interpolant value.
-    pub fn wires_value(&self, i: usize) -> Range<usize> {
-        debug_assert!(i < self.num_points());
-        let start = self.start_values() + i * D;
-        start..start + D
-    }
-
-    fn start_evaluation_point(&self) -> usize {
-        self.start_values() + self.num_points() * D
-    }
-
-    /// Wire indices of the point to evaluate the interpolant at.
-    pub fn wires_evaluation_point(&self) -> Range<usize> {
-        let start = self.start_evaluation_point();
-        start..start + D
-    }
-
-    fn start_evaluation_value(&self) -> usize {
-        self.start_evaluation_point() + D
-    }
-
-    /// Wire indices of the interpolated value.
-    pub fn wires_evaluation_value(&self) -> Range<usize> {
-        let start = self.start_evaluation_value();
-        start..start + D
-    }
-
-    fn start_coeffs(&self) -> usize {
-        self.start_evaluation_value() + D
-    }
-
-    /// The number of routed wires required in the typical usage of this gate, where the points to
-    /// interpolate, the evaluation point, and the corresponding value are all routed.
-    pub(crate) fn num_routed_wires(&self) -> usize {
-        self.start_coeffs()
-    }
-
-    /// Wire indices of the interpolant's `i`th coefficient.
-    pub fn wires_coeff(&self, i: usize) -> Range<usize> {
-        debug_assert!(i < self.num_points());
-        let start = self.start_coeffs() + i * D;
-        start..start + D
-    }
 }
 
+impl<F: RichField + Extendable<D>, const D: usize> HighDegreeInterpolationGate<F, D> {
     /// End of wire indices, exclusive.
     fn end(&self) -> usize {
         self.start_coeffs() + self.num_points() * D
@@ -121,14 +71,16 @@ impl<F: Extendable<D>, const D: usize> InterpolationGate<F, D> {
         g.powers()
             .take(size)
             .map(move |x| {
-                let subgroup_element = builder.constant(x.into());
+                let subgroup_element = builder.constant(x);
                 builder.scalar_mul_ext(subgroup_element, shift)
             })
             .collect()
     }
 }
 
-impl<F: Extendable<D>, const D: usize> Gate<F, D> for InterpolationGate<F, D> {
+impl<F: Extendable<D>, const D: usize> Gate<F, D>
    for HighDegreeInterpolationGate<F, D>
+{
     fn id(&self) -> String {
         format!("{:?}<D={}>", self, D)
     }
@@ -221,7 +173,7 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for InterpolationGate<F, D> {
     ) -> Vec<Box<dyn WitnessGenerator<F>>> {
         let gen = InterpolationGenerator::<F, D> {
             gate_index,
-            gate: self.clone(),
+            gate: *self,
             _phantom: PhantomData,
         };
         vec![Box::new(gen.adapter())]
@@ -251,7 +203,7 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for InterpolationGate<F, D> {
 #[derive(Debug)]
 struct InterpolationGenerator<F: Extendable<D>, const D: usize> {
     gate_index: usize,
-    gate: InterpolationGate<F, D>,
+    gate: HighDegreeInterpolationGate<F, D>,
     _phantom: PhantomData<F>,
 }
@@ -321,17 +273,18 @@ mod tests {
     use crate::field::field_types::Field;
     use crate::field::goldilocks_field::GoldilocksField;
+    use crate::gadgets::interpolation::InterpolationGate;
     use crate::gates::gate::Gate;
     use crate::gates::gate_testing::{test_eval_fns, test_low_degree};
-    use crate::gates::interpolation::InterpolationGate;
+    use crate::gates::interpolation::HighDegreeInterpolationGate;
     use crate::hash::hash_types::HashOut;
     use crate::plonk::config::{GenericConfig, PoseidonGoldilocksConfig};
     use crate::plonk::vars::EvaluationVars;
-    use crate::polynomial::polynomial::PolynomialCoeffs;
+    use crate::polynomial::PolynomialCoeffs;
 
     #[test]
     fn wire_indices() {
-        let gate = InterpolationGate::<GoldilocksField, 4> {
+        let gate = HighDegreeInterpolationGate::<GoldilocksField, 4> {
             subgroup_bits: 1,
             _phantom: PhantomData,
         };
@@ -350,7 +303,7 @@ mod tests {
     #[test]
     fn low_degree() {
-        test_low_degree::<GoldilocksField, _, 4>(InterpolationGate::new(2));
+        test_low_degree::<GoldilocksField, _, 4>(HighDegreeInterpolationGate::new(2));
     }
 
     #[test]
@@ -358,7 +311,7 @@ mod tests {
         const D: usize = 2;
         type C = PoseidonGoldilocksConfig;
         type F = <C as GenericConfig<D>>::F;
-        test_eval_fns::<F, C, _, D>(InterpolationGate::new(2))
+        test_eval_fns::<F, C, _, D>(HighDegreeInterpolationGate::new(2))
     }
 
     #[test]
@@ -370,7 +323,7 @@ mod tests {
     /// Returns the local wires for an interpolation gate for given coeffs, points and eval point.
     fn get_wires(
-        gate: &InterpolationGate<F, D>,
+        gate: &HighDegreeInterpolationGate<F, D>,
         shift: F,
         coeffs: PolynomialCoeffs<FF>,
         eval_point: FF,
@@ -392,7 +345,7 @@ mod tests {
         let shift = F::rand();
         let coeffs = PolynomialCoeffs::new(vec![FF::rand(), FF::rand()]);
         let eval_point = FF::rand();
-        let gate = InterpolationGate::<F, D>::new(1);
+        let gate = HighDegreeInterpolationGate::<F, D>::new(1);
         let vars = EvaluationVars {
             local_constants: &[],
             local_wires: &get_wires(&gate, shift, coeffs, eval_point),
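For orientation, both interpolation gates ultimately constrain ordinary polynomial evaluation: the wires hold the interpolant's coefficients, its values on a coset `shift * H`, an evaluation point, and the claimed evaluation. A minimal sketch with `i128` standing in for field elements (helper names are illustrative, not part of the gate API):

    /// Horner evaluation of sum_j coeffs[j] * x^j.
    fn horner(coeffs: &[i128], x: i128) -> i128 {
        coeffs.iter().rev().fold(0, |acc, &c| acc * x + c)
    }

    fn interpolation_check_demo() {
        let coeffs = [3, 0, 2]; // p(x) = 3 + 2*x^2
        // The gate checks value_i = p(shift * w^i) for each subgroup element
        // w^i, and eval_value = p(eval_point); here just p(5) = 53.
        assert_eq!(horner(&coeffs, 5), 53);
    }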

src/gates/low_degree_interpolation.rs (new file)

@@ -0,0 +1,459 @@
use std::marker::PhantomData;
use std::ops::Range;
use crate::field::extension_field::algebra::PolynomialCoeffsAlgebra;
use crate::field::extension_field::target::ExtensionTarget;
use crate::field::extension_field::{Extendable, FieldExtension};
use crate::field::field_types::{Field, RichField};
use crate::field::interpolation::interpolant;
use crate::gadgets::interpolation::InterpolationGate;
use crate::gadgets::polynomial::PolynomialCoeffsExtAlgebraTarget;
use crate::gates::gate::Gate;
use crate::iop::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator};
use crate::iop::target::Target;
use crate::iop::wire::Wire;
use crate::iop::witness::{PartitionWitness, Witness};
use crate::plonk::circuit_builder::CircuitBuilder;
use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase};
use crate::polynomial::PolynomialCoeffs;
/// Interpolation gate with constraints of degree 2.
/// `eval_unfiltered_recursively` uses more gates than `HighDegreeInterpolationGate`.
#[derive(Copy, Clone, Debug)]
pub(crate) struct LowDegreeInterpolationGate<F: RichField + Extendable<D>, const D: usize> {
pub subgroup_bits: usize,
_phantom: PhantomData<F>,
}
impl<F: RichField + Extendable<D>, const D: usize> InterpolationGate<F, D>
for LowDegreeInterpolationGate<F, D>
{
fn new(subgroup_bits: usize) -> Self {
Self {
subgroup_bits,
_phantom: PhantomData,
}
}
fn num_points(&self) -> usize {
1 << self.subgroup_bits
}
}
impl<F: RichField + Extendable<D>, const D: usize> LowDegreeInterpolationGate<F, D> {
/// `powers_shift(i)` is the wire index of `wire_shift^i`.
pub fn powers_shift(&self, i: usize) -> usize {
debug_assert!(0 < i && i < self.num_points());
if i == 1 {
return self.wire_shift();
}
self.end_coeffs() + i - 2
}
/// `powers_evaluation_point(i)` is the range of wire indices holding `evaluation_point^i`.
pub fn powers_evaluation_point(&self, i: usize) -> Range<usize> {
debug_assert!(0 < i && i < self.num_points());
if i == 1 {
return self.wires_evaluation_point();
}
let start = self.end_coeffs() + self.num_points() - 2 + (i - 2) * D;
start..start + D
}
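// Worked example (editor's sketch, not in the original source): with
// `subgroup_bits = 4` (so `num_points() = 16`) and `D = 2`, the wires at
// `end_coeffs()..end_coeffs() + 14` hold `shift^2..=shift^15`, followed by
// `14 * D = 28` wires holding the corresponding powers of the evaluation
// point, so `end()` lands at `end_coeffs() + 42`.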
/// End of wire indices, exclusive.
fn end(&self) -> usize {
self.powers_evaluation_point(self.num_points() - 1).end
}
/// The domain of the points we're interpolating.
fn coset(&self, shift: F) -> impl Iterator<Item = F> {
let g = F::primitive_root_of_unity(self.subgroup_bits);
let size = 1 << self.subgroup_bits;
// Speed matters here, so we avoid `cyclic_subgroup_coset_known_order` which allocates.
g.powers().take(size).map(move |x| x * shift)
}
}
impl<F: RichField + Extendable<D>, const D: usize> Gate<F, D> for LowDegreeInterpolationGate<F, D> {
fn id(&self) -> String {
format!("{:?}<D={}>", self, D)
}
fn eval_unfiltered(&self, vars: EvaluationVars<F, D>) -> Vec<F::Extension> {
let mut constraints = Vec::with_capacity(self.num_constraints());
let coeffs = (0..self.num_points())
.map(|i| vars.get_local_ext_algebra(self.wires_coeff(i)))
.collect::<Vec<_>>();
let mut powers_shift = (1..self.num_points())
.map(|i| vars.local_wires[self.powers_shift(i)])
.collect::<Vec<_>>();
let shift = powers_shift[0];
for i in 1..self.num_points() - 1 {
constraints.push(powers_shift[i - 1] * shift - powers_shift[i]);
}
powers_shift.insert(0, F::Extension::ONE);
// `altered_coeffs[i] = c_i * shift^i`, where `c_i` is the original coefficient.
// Then, `altered(w^i) = original(shift*w^i)`.
let altered_coeffs = coeffs
.iter()
.zip(powers_shift)
.map(|(&c, p)| c.scalar_mul(p))
.collect::<Vec<_>>();
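// Why the shifted coefficients work (editor's note): if `altered_c_i = c_i * shift^i`,
// then `altered(x) = sum_i c_i * (shift * x)^i = original(shift * x)`, so checking
// `altered` on the subgroup points `w^i` is the same as checking `original` on the
// coset points `shift * w^i` that this gate actually interpolates over.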
let interpolant = PolynomialCoeffsAlgebra::new(coeffs);
let altered_interpolant = PolynomialCoeffsAlgebra::new(altered_coeffs);
for (i, point) in F::Extension::two_adic_subgroup(self.subgroup_bits)
.into_iter()
.enumerate()
{
let value = vars.get_local_ext_algebra(self.wires_value(i));
let computed_value = altered_interpolant.eval_base(point);
constraints.extend(&(value - computed_value).to_basefield_array());
}
let evaluation_point_powers = (1..self.num_points())
.map(|i| vars.get_local_ext_algebra(self.powers_evaluation_point(i)))
.collect::<Vec<_>>();
let evaluation_point = evaluation_point_powers[0];
for i in 1..self.num_points() - 1 {
constraints.extend(
(evaluation_point_powers[i - 1] * evaluation_point - evaluation_point_powers[i])
.to_basefield_array(),
);
}
let evaluation_value = vars.get_local_ext_algebra(self.wires_evaluation_value());
let computed_evaluation_value = interpolant.eval_with_powers(&evaluation_point_powers);
constraints.extend(&(evaluation_value - computed_evaluation_value).to_basefield_array());
constraints
}
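// Degree note (editor's observation): because the powers of `shift` and of the
// evaluation point are materialized as wires and constrained pairwise via
// `p_i = p_{i-1} * x`, every product above involves at most two wires, which is
// what keeps this gate at the degree 2 reported by `degree()`.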
fn eval_unfiltered_base(&self, vars: EvaluationVarsBase<F>) -> Vec<F> {
let mut constraints = Vec::with_capacity(self.num_constraints());
let coeffs = (0..self.num_points())
.map(|i| vars.get_local_ext(self.wires_coeff(i)))
.collect::<Vec<_>>();
let mut powers_shift = (1..self.num_points())
.map(|i| vars.local_wires[self.powers_shift(i)])
.collect::<Vec<_>>();
let shift = powers_shift[0];
for i in 1..self.num_points() - 1 {
constraints.push(powers_shift[i - 1] * shift - powers_shift[i]);
}
powers_shift.insert(0, F::ONE);
// `altered_coeffs[i] = c_i * shift^i`, where `c_i` is the original coefficient.
// Then, `altered(w^i) = original(shift*w^i)`.
let altered_coeffs = coeffs
.iter()
.zip(powers_shift)
.map(|(&c, p)| c.scalar_mul(p))
.collect::<Vec<_>>();
let interpolant = PolynomialCoeffs::new(coeffs);
let altered_interpolant = PolynomialCoeffs::new(altered_coeffs);
for (i, point) in F::two_adic_subgroup(self.subgroup_bits)
.into_iter()
.enumerate()
{
let value = vars.get_local_ext(self.wires_value(i));
let computed_value = altered_interpolant.eval_base(point);
constraints.extend(&(value - computed_value).to_basefield_array());
}
let evaluation_point_powers = (1..self.num_points())
.map(|i| vars.get_local_ext(self.powers_evaluation_point(i)))
.collect::<Vec<_>>();
let evaluation_point = evaluation_point_powers[0];
for i in 1..self.num_points() - 1 {
constraints.extend(
(evaluation_point_powers[i - 1] * evaluation_point - evaluation_point_powers[i])
.to_basefield_array(),
);
}
let evaluation_value = vars.get_local_ext(self.wires_evaluation_value());
let computed_evaluation_value = interpolant.eval_with_powers(&evaluation_point_powers);
constraints.extend(&(evaluation_value - computed_evaluation_value).to_basefield_array());
constraints
}
fn eval_unfiltered_recursively(
&self,
builder: &mut CircuitBuilder<F, D>,
vars: EvaluationTargets<D>,
) -> Vec<ExtensionTarget<D>> {
let mut constraints = Vec::with_capacity(self.num_constraints());
let coeffs = (0..self.num_points())
.map(|i| vars.get_local_ext_algebra(self.wires_coeff(i)))
.collect::<Vec<_>>();
let mut powers_shift = (1..self.num_points())
.map(|i| vars.local_wires[self.powers_shift(i)])
.collect::<Vec<_>>();
let shift = powers_shift[0];
for i in 1..self.num_points() - 1 {
constraints.push(builder.mul_sub_extension(
powers_shift[i - 1],
shift,
powers_shift[i],
));
}
powers_shift.insert(0, builder.one_extension());
// `altered_coeffs[i] = c_i * shift^i`, where `c_i` is the original coefficient.
// Then, `altered(w^i) = original(shift*w^i)`.
let altered_coeffs = coeffs
.iter()
.zip(powers_shift)
.map(|(&c, p)| builder.scalar_mul_ext_algebra(p, c))
.collect::<Vec<_>>();
let interpolant = PolynomialCoeffsExtAlgebraTarget(coeffs);
let altered_interpolant = PolynomialCoeffsExtAlgebraTarget(altered_coeffs);
for (i, point) in F::Extension::two_adic_subgroup(self.subgroup_bits)
.into_iter()
.enumerate()
{
let value = vars.get_local_ext_algebra(self.wires_value(i));
let point = builder.constant_extension(point);
let computed_value = altered_interpolant.eval_scalar(builder, point);
constraints.extend(
&builder
.sub_ext_algebra(value, computed_value)
.to_ext_target_array(),
);
}
let evaluation_point_powers = (1..self.num_points())
.map(|i| vars.get_local_ext_algebra(self.powers_evaluation_point(i)))
.collect::<Vec<_>>();
let evaluation_point = evaluation_point_powers[0];
for i in 1..self.num_points() - 1 {
let neg_one_ext = builder.neg_one_extension();
let neg_new_power =
builder.scalar_mul_ext_algebra(neg_one_ext, evaluation_point_powers[i]);
let constraint = builder.mul_add_ext_algebra(
evaluation_point,
evaluation_point_powers[i - 1],
neg_new_power,
);
constraints.extend(constraint.to_ext_target_array());
}
let evaluation_value = vars.get_local_ext_algebra(self.wires_evaluation_value());
let computed_evaluation_value =
interpolant.eval_with_powers(builder, &evaluation_point_powers);
constraints.extend(
&builder
.sub_ext_algebra(evaluation_value, computed_evaluation_value)
.to_ext_target_array(),
);
constraints
}
fn generators(
&self,
gate_index: usize,
_local_constants: &[F],
) -> Vec<Box<dyn WitnessGenerator<F>>> {
let gen = InterpolationGenerator::<F, D> {
gate_index,
gate: *self,
_phantom: PhantomData,
};
vec![Box::new(gen.adapter())]
}
fn num_wires(&self) -> usize {
self.end()
}
fn num_constants(&self) -> usize {
0
}
fn degree(&self) -> usize {
2
}
fn num_constraints(&self) -> usize {
// `num_points * D` constraints to check for consistency between the coefficients and the
// point-value pairs, plus `D` constraints for the evaluation value, plus `(D+1)*(num_points-2)`
// to check power constraints for evaluation point and shift.
self.num_points() * D + D + (D + 1) * (self.num_points() - 2)
}
}
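// Worked count (editor's sketch): for the test instance below with
// `subgroup_bits = 4` (`num_points() = 16`) and `D = 2`, this is
// `16 * 2 + 2 + 3 * 14 = 76` constraints.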
#[derive(Debug)]
struct InterpolationGenerator<F: RichField + Extendable<D>, const D: usize> {
gate_index: usize,
gate: LowDegreeInterpolationGate<F, D>,
_phantom: PhantomData<F>,
}
impl<F: RichField + Extendable<D>, const D: usize> SimpleGenerator<F>
for InterpolationGenerator<F, D>
{
fn dependencies(&self) -> Vec<Target> {
let local_target = |input| {
Target::Wire(Wire {
gate: self.gate_index,
input,
})
};
let local_targets = |inputs: Range<usize>| inputs.map(local_target);
let num_points = self.gate.num_points();
let mut deps = Vec::with_capacity(1 + D + num_points * D);
deps.push(local_target(self.gate.wire_shift()));
deps.extend(local_targets(self.gate.wires_evaluation_point()));
for i in 0..num_points {
deps.extend(local_targets(self.gate.wires_value(i)));
}
deps
}
fn run_once(&self, witness: &PartitionWitness<F>, out_buffer: &mut GeneratedValues<F>) {
let local_wire = |input| Wire {
gate: self.gate_index,
input,
};
let get_local_wire = |input| witness.get_wire(local_wire(input));
let get_local_ext = |wire_range: Range<usize>| {
debug_assert_eq!(wire_range.len(), D);
let values = wire_range.map(get_local_wire).collect::<Vec<_>>();
let arr = values.try_into().unwrap();
F::Extension::from_basefield_array(arr)
};
let wire_shift = get_local_wire(self.gate.wire_shift());
for (i, power) in wire_shift
.powers()
.take(self.gate.num_points())
.enumerate()
.skip(2)
{
out_buffer.set_wire(local_wire(self.gate.powers_shift(i)), power);
}
// Compute the interpolant.
let points = self.gate.coset(wire_shift);
let points = points
.into_iter()
.enumerate()
.map(|(i, point)| (point.into(), get_local_ext(self.gate.wires_value(i))))
.collect::<Vec<_>>();
let interpolant = interpolant(&points);
for (i, &coeff) in interpolant.coeffs.iter().enumerate() {
let wires = self.gate.wires_coeff(i).map(local_wire);
out_buffer.set_ext_wires(wires, coeff);
}
let evaluation_point = get_local_ext(self.gate.wires_evaluation_point());
for (i, power) in evaluation_point
.powers()
.take(self.gate.num_points())
.enumerate()
.skip(2)
{
out_buffer.set_extension_target(
ExtensionTarget::from_range(self.gate_index, self.gate.powers_evaluation_point(i)),
power,
);
}
let evaluation_value = interpolant.eval(evaluation_point);
let evaluation_value_wires = self.gate.wires_evaluation_value().map(local_wire);
out_buffer.set_ext_wires(evaluation_value_wires, evaluation_value);
}
}
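// Editor's note: both power-filling loops above `.skip(2)` because `x^0 = 1`
// needs no wire and `x^1` aliases an existing wire (`powers_shift(1)` returns
// `wire_shift`, and `powers_evaluation_point(1)` returns the evaluation-point
// wires), so only the powers from `x^2` upward occupy fresh wires.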
#[cfg(test)]
mod tests {
use anyhow::Result;
use crate::field::extension_field::quadratic::QuadraticExtension;
use crate::field::field_types::Field;
use crate::field::goldilocks_field::GoldilocksField;
use crate::gadgets::interpolation::InterpolationGate;
use crate::gates::gate::Gate;
use crate::gates::gate_testing::{test_eval_fns, test_low_degree};
use crate::gates::low_degree_interpolation::LowDegreeInterpolationGate;
use crate::hash::hash_types::HashOut;
use crate::plonk::vars::EvaluationVars;
use crate::polynomial::PolynomialCoeffs;
#[test]
fn low_degree() {
test_low_degree::<GoldilocksField, _, 4>(LowDegreeInterpolationGate::new(4));
}
#[test]
fn eval_fns() -> Result<()> {
test_eval_fns::<GoldilocksField, _, 4>(LowDegreeInterpolationGate::new(4))
}
#[test]
fn test_gate_constraint() {
type F = GoldilocksField;
type FF = QuadraticExtension<GoldilocksField>;
const D: usize = 2;
/// Returns the local wires for an interpolation gate for given coeffs, points and eval point.
fn get_wires(
gate: &LowDegreeInterpolationGate<F, D>,
shift: F,
coeffs: PolynomialCoeffs<FF>,
eval_point: FF,
) -> Vec<FF> {
let points = gate.coset(shift);
let mut v = vec![shift];
for x in points {
v.extend(coeffs.eval(x.into()).0);
}
v.extend(eval_point.0);
v.extend(coeffs.eval(eval_point).0);
for i in 0..coeffs.len() {
v.extend(coeffs.coeffs[i].0);
}
v.extend(shift.powers().skip(2).take(gate.num_points() - 2));
v.extend(
eval_point
.powers()
.skip(2)
.take(gate.num_points() - 2)
.flat_map(|ff| ff.0),
);
v.iter().map(|&x| x.into()).collect::<Vec<_>>()
}
// Get a working row for LowDegreeInterpolationGate.
let subgroup_bits = 4;
let shift = F::rand();
let coeffs = PolynomialCoeffs::new(FF::rand_vec(1 << subgroup_bits));
let eval_point = FF::rand();
let gate = LowDegreeInterpolationGate::<F, D>::new(subgroup_bits);
let vars = EvaluationVars {
local_constants: &[],
local_wires: &get_wires(&gate, shift, coeffs, eval_point),
public_inputs_hash: &HashOut::rand(),
};
assert!(
gate.eval_unfiltered(vars).iter().all(|x| x.is_zero()),
"Gate constraints are not satisfied."
);
}
}

src/gates/mod.rs

@@ -1,8 +1,10 @@
 // Gates have `new` methods that return `GateRef`s.
 #![allow(clippy::new_ret_no_self)]
 
-pub mod arithmetic;
+pub mod arithmetic_base;
+pub mod arithmetic_extension;
 pub mod arithmetic_u32;
+pub mod assert_le;
 pub mod base_sum;
 pub mod comparison;
 pub mod constant;
@@ -12,12 +14,16 @@ pub mod gate_tree;
 pub mod gmimc;
 pub mod insertion;
 pub mod interpolation;
+pub mod low_degree_interpolation;
+pub mod multiplication_extension;
 pub mod noop;
 pub mod poseidon;
 pub(crate) mod poseidon_mds;
 pub(crate) mod public_input;
 pub mod random_access;
 pub mod reducing;
+pub mod reducing_extension;
+pub mod subtraction_u32;
 pub mod switch;
 
 #[cfg(test)]

src/gates/multiplication_extension.rs (new file)

@@ -0,0 +1,204 @@
use std::ops::Range;
use crate::field::extension_field::target::ExtensionTarget;
use crate::field::extension_field::Extendable;
use crate::field::extension_field::FieldExtension;
use crate::field::field_types::RichField;
use crate::gates::gate::Gate;
use crate::iop::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator};
use crate::iop::target::Target;
use crate::iop::witness::{PartitionWitness, Witness};
use crate::plonk::circuit_builder::CircuitBuilder;
use crate::plonk::circuit_data::CircuitConfig;
use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase};
/// A gate which can perform a weighted multiplication, i.e. `result = c0 * x * y`. If the config
/// provides enough routed wires, it can pack several such operations into one gate.
#[derive(Debug)]
pub struct MulExtensionGate<const D: usize> {
/// Number of multiplications performed by the gate.
pub num_ops: usize,
}
impl<const D: usize> MulExtensionGate<D> {
pub fn new_from_config(config: &CircuitConfig) -> Self {
Self {
num_ops: Self::num_ops(config),
}
}
/// Determine the maximum number of operations that can fit in one gate for the given config.
pub(crate) fn num_ops(config: &CircuitConfig) -> usize {
let wires_per_op = 3 * D;
config.num_routed_wires / wires_per_op
}
pub fn wires_ith_multiplicand_0(i: usize) -> Range<usize> {
3 * D * i..3 * D * i + D
}
pub fn wires_ith_multiplicand_1(i: usize) -> Range<usize> {
3 * D * i + D..3 * D * i + 2 * D
}
pub fn wires_ith_output(i: usize) -> Range<usize> {
3 * D * i + 2 * D..3 * D * i + 3 * D
}
}
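// Worked layout (editor's sketch): with `D = 2` each op consumes `3 * D = 6`
// routed wires, laid out per op as `[x_0, x_1, y_0, y_1, out_0, out_1]`; a
// config routing, say, 80 wires would therefore pack `80 / 6 = 13` independent
// multiplications into a single gate.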
impl<F: RichField + Extendable<D>, const D: usize> Gate<F, D> for MulExtensionGate<D> {
fn id(&self) -> String {
format!("{:?}", self)
}
fn eval_unfiltered(&self, vars: EvaluationVars<F, D>) -> Vec<F::Extension> {
let const_0 = vars.local_constants[0];
let mut constraints = Vec::new();
for i in 0..self.num_ops {
let multiplicand_0 = vars.get_local_ext_algebra(Self::wires_ith_multiplicand_0(i));
let multiplicand_1 = vars.get_local_ext_algebra(Self::wires_ith_multiplicand_1(i));
let output = vars.get_local_ext_algebra(Self::wires_ith_output(i));
let computed_output = (multiplicand_0 * multiplicand_1).scalar_mul(const_0);
constraints.extend((output - computed_output).to_basefield_array());
}
constraints
}
fn eval_unfiltered_base(&self, vars: EvaluationVarsBase<F>) -> Vec<F> {
let const_0 = vars.local_constants[0];
let mut constraints = Vec::new();
for i in 0..self.num_ops {
let multiplicand_0 = vars.get_local_ext(Self::wires_ith_multiplicand_0(i));
let multiplicand_1 = vars.get_local_ext(Self::wires_ith_multiplicand_1(i));
let output = vars.get_local_ext(Self::wires_ith_output(i));
let computed_output = (multiplicand_0 * multiplicand_1).scalar_mul(const_0);
constraints.extend((output - computed_output).to_basefield_array());
}
constraints
}
fn eval_unfiltered_recursively(
&self,
builder: &mut CircuitBuilder<F, D>,
vars: EvaluationTargets<D>,
) -> Vec<ExtensionTarget<D>> {
let const_0 = vars.local_constants[0];
let mut constraints = Vec::new();
for i in 0..self.num_ops {
let multiplicand_0 = vars.get_local_ext_algebra(Self::wires_ith_multiplicand_0(i));
let multiplicand_1 = vars.get_local_ext_algebra(Self::wires_ith_multiplicand_1(i));
let output = vars.get_local_ext_algebra(Self::wires_ith_output(i));
let computed_output = {
let mul = builder.mul_ext_algebra(multiplicand_0, multiplicand_1);
builder.scalar_mul_ext_algebra(const_0, mul)
};
let diff = builder.sub_ext_algebra(output, computed_output);
constraints.extend(diff.to_ext_target_array());
}
constraints
}
fn generators(
&self,
gate_index: usize,
local_constants: &[F],
) -> Vec<Box<dyn WitnessGenerator<F>>> {
(0..self.num_ops)
.map(|i| {
let g: Box<dyn WitnessGenerator<F>> = Box::new(
MulExtensionGenerator {
gate_index,
const_0: local_constants[0],
i,
}
.adapter(),
);
g
})
.collect::<Vec<_>>()
}
fn num_wires(&self) -> usize {
self.num_ops * 3 * D
}
fn num_constants(&self) -> usize {
1
}
fn degree(&self) -> usize {
3
}
fn num_constraints(&self) -> usize {
self.num_ops * D
}
}
#[derive(Clone, Debug)]
struct MulExtensionGenerator<F: RichField + Extendable<D>, const D: usize> {
gate_index: usize,
const_0: F,
i: usize,
}
impl<F: RichField + Extendable<D>, const D: usize> SimpleGenerator<F>
for MulExtensionGenerator<F, D>
{
fn dependencies(&self) -> Vec<Target> {
MulExtensionGate::<D>::wires_ith_multiplicand_0(self.i)
.chain(MulExtensionGate::<D>::wires_ith_multiplicand_1(self.i))
.map(|i| Target::wire(self.gate_index, i))
.collect()
}
fn run_once(&self, witness: &PartitionWitness<F>, out_buffer: &mut GeneratedValues<F>) {
let extract_extension = |range: Range<usize>| -> F::Extension {
let t = ExtensionTarget::from_range(self.gate_index, range);
witness.get_extension_target(t)
};
let multiplicand_0 =
extract_extension(MulExtensionGate::<D>::wires_ith_multiplicand_0(self.i));
let multiplicand_1 =
extract_extension(MulExtensionGate::<D>::wires_ith_multiplicand_1(self.i));
let output_target = ExtensionTarget::from_range(
self.gate_index,
MulExtensionGate::<D>::wires_ith_output(self.i),
);
let computed_output = (multiplicand_0 * multiplicand_1).scalar_mul(self.const_0);
out_buffer.set_extension_target(output_target, computed_output)
}
}
#[cfg(test)]
mod tests {
use anyhow::Result;
use crate::field::goldilocks_field::GoldilocksField;
use crate::gates::gate_testing::{test_eval_fns, test_low_degree};
use crate::gates::multiplication_extension::MulExtensionGate;
use crate::plonk::circuit_data::CircuitConfig;
#[test]
fn low_degree() {
let gate = MulExtensionGate::new_from_config(&CircuitConfig::standard_recursion_config());
test_low_degree::<GoldilocksField, _, 4>(gate);
}
#[test]
fn eval_fns() -> Result<()> {
let gate = MulExtensionGate::new_from_config(&CircuitConfig::standard_recursion_config());
test_eval_fns::<GoldilocksField, _, 4>(gate)
}
}

src/gates/poseidon.rs

@@ -1,4 +1,3 @@
-use std::convert::TryInto;
 use std::marker::PhantomData;
 
 use crate::field::extension_field::target::ExtensionTarget;
@@ -47,44 +46,59 @@ impl<F: Extendable<D>, const D: usize> PoseidonGate<F, D> {
     /// is useful for ordering hashes in Merkle proofs. Otherwise, this should be set to 0.
     pub const WIRE_SWAP: usize = 2 * SPONGE_WIDTH;
 
+    const START_DELTA: usize = 2 * WIDTH + 1;
+
+    /// A wire which stores `swap * (input[i + 4] - input[i])`; used to compute the swapped inputs.
+    fn wire_delta(i: usize) -> usize {
+        assert!(i < 4);
+        Self::START_DELTA + i
+    }
+
+    const START_FULL_0: usize = Self::START_DELTA + 4;
+
     /// A wire which stores the input of the `i`-th S-box of the `round`-th round of the first set
     /// of full rounds.
     fn wire_full_sbox_0(round: usize, i: usize) -> usize {
+        debug_assert!(
+            round != 0,
+            "First round S-box inputs are not stored as wires"
+        );
         debug_assert!(round < poseidon::HALF_N_FULL_ROUNDS);
-        debug_assert!(i < SPONGE_WIDTH);
-        2 * SPONGE_WIDTH + 1 + SPONGE_WIDTH * round + i
+        debug_assert!(i < WIDTH);
+        Self::START_FULL_0 + WIDTH * (round - 1) + i
     }
 
+    const START_PARTIAL: usize = Self::START_FULL_0 + WIDTH * (poseidon::HALF_N_FULL_ROUNDS - 1);
+
     /// A wire which stores the input of the S-box of the `round`-th round of the partial rounds.
     fn wire_partial_sbox(round: usize) -> usize {
         debug_assert!(round < poseidon::N_PARTIAL_ROUNDS);
-        2 * SPONGE_WIDTH + 1 + SPONGE_WIDTH * poseidon::HALF_N_FULL_ROUNDS + round
+        Self::START_PARTIAL + round
     }
 
+    const START_FULL_1: usize = Self::START_PARTIAL + poseidon::N_PARTIAL_ROUNDS;
+
     /// A wire which stores the input of the `i`-th S-box of the `round`-th round of the second set
     /// of full rounds.
     fn wire_full_sbox_1(round: usize, i: usize) -> usize {
         debug_assert!(round < poseidon::HALF_N_FULL_ROUNDS);
-        debug_assert!(i < SPONGE_WIDTH);
-        2 * SPONGE_WIDTH
-            + 1
-            + SPONGE_WIDTH * (poseidon::HALF_N_FULL_ROUNDS + round)
-            + poseidon::N_PARTIAL_ROUNDS
-            + i
+        debug_assert!(i < WIDTH);
+        Self::START_FULL_1 + WIDTH * round + i
     }
 
     /// End of wire indices, exclusive.
     fn end() -> usize {
-        2 * SPONGE_WIDTH
-            + 1
-            + SPONGE_WIDTH * poseidon::N_FULL_ROUNDS_TOTAL
-            + poseidon::N_PARTIAL_ROUNDS
+        Self::START_FULL_1 + WIDTH * poseidon::HALF_N_FULL_ROUNDS
     }
 }
 
-impl<F: Extendable<D>, const D: usize> Gate<F, D> for PoseidonGate<F, D> {
+impl<F: RichField + Extendable<D> + Poseidon<WIDTH>, const D: usize, const WIDTH: usize> Gate<F, D>
+    for PoseidonGate<F, D, WIDTH>
+where
+    [(); WIDTH - 1]:,
+{
     fn id(&self) -> String {
-        format!("{:?}<SPONGE_WIDTH={}>", self, SPONGE_WIDTH)
+        format!("{:?}<WIDTH={}>", self, WIDTH)
     }
 
     fn eval_unfiltered(&self, vars: EvaluationVars<F, D>) -> Vec<F::Extension> {
@@ -94,69 +108,79 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for PoseidonGate<F, D> {
         let swap = vars.local_wires[Self::WIRE_SWAP];
         constraints.push(swap * (swap - F::Extension::ONE));
 
-        let mut state = Vec::with_capacity(SPONGE_WIDTH);
+        // Assert that each delta wire is set properly: `delta_i = swap * (rhs - lhs)`.
         for i in 0..4 {
-            let a = vars.local_wires[i];
-            let b = vars.local_wires[i + 4];
-            state.push(a + swap * (b - a));
-        }
-        for i in 0..4 {
-            let a = vars.local_wires[i + 4];
-            let b = vars.local_wires[i];
-            state.push(a + swap * (b - a));
-        }
-        for i in 8..SPONGE_WIDTH {
-            state.push(vars.local_wires[i]);
+            let input_lhs = vars.local_wires[Self::wire_input(i)];
+            let input_rhs = vars.local_wires[Self::wire_input(i + 4)];
+            let delta_i = vars.local_wires[Self::wire_delta(i)];
+            constraints.push(swap * (input_rhs - input_lhs) - delta_i);
+        }
+
+        // Compute the possibly-swapped input layer.
+        let mut state = [F::Extension::ZERO; WIDTH];
+        for i in 0..4 {
+            let delta_i = vars.local_wires[Self::wire_delta(i)];
+            let input_lhs = Self::wire_input(i);
+            let input_rhs = Self::wire_input(i + 4);
+            state[i] = vars.local_wires[input_lhs] + delta_i;
+            state[i + 4] = vars.local_wires[input_rhs] - delta_i;
+        }
+        for i in 8..WIDTH {
+            state[i] = vars.local_wires[Self::wire_input(i)];
         }
 
-        let mut state: [F::Extension; SPONGE_WIDTH] = state.try_into().unwrap();
         let mut round_ctr = 0;
 
         // First set of full rounds.
         for r in 0..poseidon::HALF_N_FULL_ROUNDS {
-            <F as Poseidon>::constant_layer_field(&mut state, round_ctr);
-            for i in 0..SPONGE_WIDTH {
+            <F as Poseidon<WIDTH>>::constant_layer_field(&mut state, round_ctr);
+            if r != 0 {
+                for i in 0..WIDTH {
                 let sbox_in = vars.local_wires[Self::wire_full_sbox_0(r, i)];
                 constraints.push(state[i] - sbox_in);
                 state[i] = sbox_in;
+                }
             }
-            <F as Poseidon>::sbox_layer_field(&mut state);
-            state = <F as Poseidon>::mds_layer_field(&state);
+            <F as Poseidon<WIDTH>>::sbox_layer_field(&mut state);
+            state = <F as Poseidon<WIDTH>>::mds_layer_field(&state);
             round_ctr += 1;
         }
 
         // Partial rounds.
-        <F as Poseidon>::partial_first_constant_layer(&mut state);
-        state = <F as Poseidon>::mds_partial_layer_init(&mut state);
+        <F as Poseidon<WIDTH>>::partial_first_constant_layer(&mut state);
+        state = <F as Poseidon<WIDTH>>::mds_partial_layer_init(&state);
         for r in 0..(poseidon::N_PARTIAL_ROUNDS - 1) {
             let sbox_in = vars.local_wires[Self::wire_partial_sbox(r)];
             constraints.push(state[0] - sbox_in);
-            state[0] = <F as Poseidon>::sbox_monomial(sbox_in);
-            state[0] +=
-                F::Extension::from_canonical_u64(<F as Poseidon>::FAST_PARTIAL_ROUND_CONSTANTS[r]);
-            state = <F as Poseidon>::mds_partial_layer_fast_field(&state, r);
+            state[0] = <F as Poseidon<WIDTH>>::sbox_monomial(sbox_in);
+            state[0] += F::Extension::from_canonical_u64(
+                <F as Poseidon<WIDTH>>::FAST_PARTIAL_ROUND_CONSTANTS[r],
+            );
+            state = <F as Poseidon<WIDTH>>::mds_partial_layer_fast_field(&state, r);
         }
         let sbox_in = vars.local_wires[Self::wire_partial_sbox(poseidon::N_PARTIAL_ROUNDS - 1)];
         constraints.push(state[0] - sbox_in);
-        state[0] = <F as Poseidon>::sbox_monomial(sbox_in);
-        state =
-            <F as Poseidon>::mds_partial_layer_fast_field(&state, poseidon::N_PARTIAL_ROUNDS - 1);
+        state[0] = <F as Poseidon<WIDTH>>::sbox_monomial(sbox_in);
+        state = <F as Poseidon<WIDTH>>::mds_partial_layer_fast_field(
+            &state,
+            poseidon::N_PARTIAL_ROUNDS - 1,
+        );
         round_ctr += poseidon::N_PARTIAL_ROUNDS;
 
         // Second set of full rounds.
         for r in 0..poseidon::HALF_N_FULL_ROUNDS {
-            <F as Poseidon>::constant_layer_field(&mut state, round_ctr);
-            for i in 0..SPONGE_WIDTH {
+            <F as Poseidon<WIDTH>>::constant_layer_field(&mut state, round_ctr);
+            for i in 0..WIDTH {
                 let sbox_in = vars.local_wires[Self::wire_full_sbox_1(r, i)];
                 constraints.push(state[i] - sbox_in);
                 state[i] = sbox_in;
             }
-            <F as Poseidon>::sbox_layer_field(&mut state);
-            state = <F as Poseidon>::mds_layer_field(&state);
+            <F as Poseidon<WIDTH>>::sbox_layer_field(&mut state);
+            state = <F as Poseidon<WIDTH>>::mds_layer_field(&state);
             round_ctr += 1;
         }
 
-        for i in 0..SPONGE_WIDTH {
+        for i in 0..WIDTH {
             constraints.push(state[i] - vars.local_wires[Self::wire_output(i)]);
         }
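// Editor's note on the delta wires: the rewrite above replaces the inline
// `a + swap * (b - a)` selection with dedicated `delta_i` wires, so the swapped
// input layer is linear in the wires (`input ± delta_i`) and the correctness of
// each `delta_i` is enforced by its own degree-2 constraint.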
@@ -170,67 +194,76 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for PoseidonGate<F, D> {
         let swap = vars.local_wires[Self::WIRE_SWAP];
         constraints.push(swap * swap.sub_one());
 
-        let mut state = Vec::with_capacity(SPONGE_WIDTH);
+        // Assert that each delta wire is set properly: `delta_i = swap * (rhs - lhs)`.
         for i in 0..4 {
-            let a = vars.local_wires[i];
-            let b = vars.local_wires[i + 4];
-            state.push(a + swap * (b - a));
-        }
-        for i in 0..4 {
-            let a = vars.local_wires[i + 4];
-            let b = vars.local_wires[i];
-            state.push(a + swap * (b - a));
-        }
-        for i in 8..SPONGE_WIDTH {
-            state.push(vars.local_wires[i]);
+            let input_lhs = vars.local_wires[Self::wire_input(i)];
+            let input_rhs = vars.local_wires[Self::wire_input(i + 4)];
+            let delta_i = vars.local_wires[Self::wire_delta(i)];
+            constraints.push(swap * (input_rhs - input_lhs) - delta_i);
+        }
+
+        // Compute the possibly-swapped input layer.
+        let mut state = [F::ZERO; WIDTH];
+        for i in 0..4 {
+            let delta_i = vars.local_wires[Self::wire_delta(i)];
+            let input_lhs = Self::wire_input(i);
+            let input_rhs = Self::wire_input(i + 4);
+            state[i] = vars.local_wires[input_lhs] + delta_i;
+            state[i + 4] = vars.local_wires[input_rhs] - delta_i;
+        }
+        for i in 8..WIDTH {
+            state[i] = vars.local_wires[Self::wire_input(i)];
         }
 
-        let mut state: [F; SPONGE_WIDTH] = state.try_into().unwrap();
         let mut round_ctr = 0;
 
         // First set of full rounds.
         for r in 0..poseidon::HALF_N_FULL_ROUNDS {
-            <F as Poseidon>::constant_layer(&mut state, round_ctr);
-            for i in 0..SPONGE_WIDTH {
+            <F as Poseidon<WIDTH>>::constant_layer(&mut state, round_ctr);
+            if r != 0 {
+                for i in 0..WIDTH {
                 let sbox_in = vars.local_wires[Self::wire_full_sbox_0(r, i)];
                 constraints.push(state[i] - sbox_in);
                 state[i] = sbox_in;
+                }
             }
-            <F as Poseidon>::sbox_layer(&mut state);
-            state = <F as Poseidon>::mds_layer(&state);
+            <F as Poseidon<WIDTH>>::sbox_layer(&mut state);
+            state = <F as Poseidon<WIDTH>>::mds_layer(&state);
             round_ctr += 1;
         }
 
         // Partial rounds.
-        <F as Poseidon>::partial_first_constant_layer(&mut state);
-        state = <F as Poseidon>::mds_partial_layer_init(&mut state);
+        <F as Poseidon<WIDTH>>::partial_first_constant_layer(&mut state);
+        state = <F as Poseidon<WIDTH>>::mds_partial_layer_init(&state);
         for r in 0..(poseidon::N_PARTIAL_ROUNDS - 1) {
             let sbox_in = vars.local_wires[Self::wire_partial_sbox(r)];
             constraints.push(state[0] - sbox_in);
-            state[0] = <F as Poseidon>::sbox_monomial(sbox_in);
-            state[0] += F::from_canonical_u64(<F as Poseidon>::FAST_PARTIAL_ROUND_CONSTANTS[r]);
-            state = <F as Poseidon>::mds_partial_layer_fast(&state, r);
+            state[0] = <F as Poseidon<WIDTH>>::sbox_monomial(sbox_in);
+            state[0] +=
+                F::from_canonical_u64(<F as Poseidon<WIDTH>>::FAST_PARTIAL_ROUND_CONSTANTS[r]);
+            state = <F as Poseidon<WIDTH>>::mds_partial_layer_fast(&state, r);
         }
         let sbox_in = vars.local_wires[Self::wire_partial_sbox(poseidon::N_PARTIAL_ROUNDS - 1)];
         constraints.push(state[0] - sbox_in);
-        state[0] = <F as Poseidon>::sbox_monomial(sbox_in);
-        state = <F as Poseidon>::mds_partial_layer_fast(&state, poseidon::N_PARTIAL_ROUNDS - 1);
+        state[0] = <F as Poseidon<WIDTH>>::sbox_monomial(sbox_in);
+        state =
+            <F as Poseidon<WIDTH>>::mds_partial_layer_fast(&state, poseidon::N_PARTIAL_ROUNDS - 1);
         round_ctr += poseidon::N_PARTIAL_ROUNDS;
 
         // Second set of full rounds.
         for r in 0..poseidon::HALF_N_FULL_ROUNDS {
-            <F as Poseidon>::constant_layer(&mut state, round_ctr);
-            for i in 0..SPONGE_WIDTH {
+            <F as Poseidon<WIDTH>>::constant_layer(&mut state, round_ctr);
+            for i in 0..WIDTH {
                 let sbox_in = vars.local_wires[Self::wire_full_sbox_1(r, i)];
                 constraints.push(state[i] - sbox_in);
                 state[i] = sbox_in;
             }
-            <F as Poseidon>::sbox_layer(&mut state);
-            state = <F as Poseidon>::mds_layer(&state);
+            <F as Poseidon<WIDTH>>::sbox_layer(&mut state);
+            state = <F as Poseidon<WIDTH>>::mds_layer(&state);
             round_ctr += 1;
         }
 
-        for i in 0..SPONGE_WIDTH {
+        for i in 0..WIDTH {
             constraints.push(state[i] - vars.local_wires[Self::wire_output(i)]);
         }
@@ -244,7 +277,7 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for PoseidonGate<F, D> {
     ) -> Vec<ExtensionTarget<D>> {
         // The naive method is more efficient if we have enough routed wires for PoseidonMdsGate.
         let use_mds_gate =
-            builder.config.num_routed_wires >= PoseidonMdsGate::<F, D>::new().num_wires();
+            builder.config.num_routed_wires >= PoseidonMdsGate::<F, D, WIDTH>::new().num_wires();
 
         let mut constraints = Vec::with_capacity(self.num_constraints());
@@ -252,71 +285,73 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for PoseidonGate<F, D> {
         let swap = vars.local_wires[Self::WIRE_SWAP];
         constraints.push(builder.mul_sub_extension(swap, swap, swap));
 
-        let mut state = Vec::with_capacity(SPONGE_WIDTH);
-        // We need to compute both `if swap {b} else {a}` and `if swap {a} else {b}`.
-        // We will arithmetize them as
-        //     swap (b - a) + a
-        //     -swap (b - a) + b
-        // so that `b - a` can be used for both.
-        let mut state_first_4 = vec![];
-        let mut state_next_4 = vec![];
+        // Assert that each delta wire is set properly: `delta_i = swap * (rhs - lhs)`.
         for i in 0..4 {
-            let a = vars.local_wires[i];
-            let b = vars.local_wires[i + 4];
-            let delta = builder.sub_extension(b, a);
-            state_first_4.push(builder.mul_add_extension(swap, delta, a));
-            state_next_4.push(builder.arithmetic_extension(F::NEG_ONE, F::ONE, swap, delta, b));
+            let input_lhs = vars.local_wires[Self::wire_input(i)];
+            let input_rhs = vars.local_wires[Self::wire_input(i + 4)];
+            let delta_i = vars.local_wires[Self::wire_delta(i)];
+            let diff = builder.sub_extension(input_rhs, input_lhs);
+            constraints.push(builder.mul_sub_extension(swap, diff, delta_i));
         }
-        state.extend(state_first_4);
-        state.extend(state_next_4);
-        for i in 8..SPONGE_WIDTH {
-            state.push(vars.local_wires[i]);
+
+        // Compute the possibly-swapped input layer.
+        let mut state = [builder.zero_extension(); WIDTH];
+        for i in 0..4 {
+            let delta_i = vars.local_wires[Self::wire_delta(i)];
+            let input_lhs = vars.local_wires[Self::wire_input(i)];
+            let input_rhs = vars.local_wires[Self::wire_input(i + 4)];
+            state[i] = builder.add_extension(input_lhs, delta_i);
+            state[i + 4] = builder.sub_extension(input_rhs, delta_i);
+        }
+        for i in 8..WIDTH {
+            state[i] = vars.local_wires[Self::wire_input(i)];
         }
 
-        let mut state: [ExtensionTarget<D>; SPONGE_WIDTH] = state.try_into().unwrap();
         let mut round_ctr = 0;
 
         // First set of full rounds.
         for r in 0..poseidon::HALF_N_FULL_ROUNDS {
-            <F as Poseidon>::constant_layer_recursive(builder, &mut state, round_ctr);
-            for i in 0..SPONGE_WIDTH {
+            <F as Poseidon<WIDTH>>::constant_layer_recursive(builder, &mut state, round_ctr);
+            if r != 0 {
+                for i in 0..WIDTH {
                 let sbox_in = vars.local_wires[Self::wire_full_sbox_0(r, i)];
                 constraints.push(builder.sub_extension(state[i], sbox_in));
                 state[i] = sbox_in;
+                }
             }
-            <F as Poseidon>::sbox_layer_recursive(builder, &mut state);
-            state = <F as Poseidon>::mds_layer_recursive(builder, &state);
+            <F as Poseidon<WIDTH>>::sbox_layer_recursive(builder, &mut state);
+            state = <F as Poseidon<WIDTH>>::mds_layer_recursive(builder, &state);
             round_ctr += 1;
         }
 
         // Partial rounds.
         if use_mds_gate {
             for r in 0..poseidon::N_PARTIAL_ROUNDS {
-                <F as Poseidon>::constant_layer_recursive(builder, &mut state, round_ctr);
+                <F as Poseidon<WIDTH>>::constant_layer_recursive(builder, &mut state, round_ctr);
                 let sbox_in = vars.local_wires[Self::wire_partial_sbox(r)];
                 constraints.push(builder.sub_extension(state[0], sbox_in));
-                state[0] = <F as Poseidon>::sbox_monomial_recursive(builder, sbox_in);
-                state = <F as Poseidon>::mds_layer_recursive(builder, &state);
+                state[0] = <F as Poseidon<WIDTH>>::sbox_monomial_recursive(builder, sbox_in);
+                state = <F as Poseidon<WIDTH>>::mds_layer_recursive(builder, &state);
                 round_ctr += 1;
             }
         } else {
-            <F as Poseidon>::partial_first_constant_layer_recursive(builder, &mut state);
-            state = <F as Poseidon>::mds_partial_layer_init_recursive(builder, &mut state);
+            <F as Poseidon<WIDTH>>::partial_first_constant_layer_recursive(builder, &mut state);
+            state = <F as Poseidon<WIDTH>>::mds_partial_layer_init_recursive(builder, &state);
             for r in 0..(poseidon::N_PARTIAL_ROUNDS - 1) {
                 let sbox_in = vars.local_wires[Self::wire_partial_sbox(r)];
                 constraints.push(builder.sub_extension(state[0], sbox_in));
-                state[0] = <F as Poseidon>::sbox_monomial_recursive(builder, sbox_in);
-                state[0] = builder.add_const_extension(
-                    state[0],
-                    F::from_canonical_u64(<F as Poseidon>::FAST_PARTIAL_ROUND_CONSTANTS[r]),
-                );
-                state = <F as Poseidon>::mds_partial_layer_fast_recursive(builder, &state, r);
+                state[0] = <F as Poseidon<WIDTH>>::sbox_monomial_recursive(builder, sbox_in);
+                let c = <F as Poseidon<WIDTH>>::FAST_PARTIAL_ROUND_CONSTANTS[r];
+                let c = F::Extension::from_canonical_u64(c);
+                let c = builder.constant_extension(c);
+                state[0] = builder.add_extension(state[0], c);
+                state =
+                    <F as Poseidon<WIDTH>>::mds_partial_layer_fast_recursive(builder, &state, r);
             }
             let sbox_in = vars.local_wires[Self::wire_partial_sbox(poseidon::N_PARTIAL_ROUNDS - 1)];
             constraints.push(builder.sub_extension(state[0], sbox_in));
-            state[0] = <F as Poseidon>::sbox_monomial_recursive(builder, sbox_in);
-            state = <F as Poseidon>::mds_partial_layer_fast_recursive(
+            state[0] = <F as Poseidon<WIDTH>>::sbox_monomial_recursive(builder, sbox_in);
+            state = <F as Poseidon<WIDTH>>::mds_partial_layer_fast_recursive(
                 builder,
                 &state,
                 poseidon::N_PARTIAL_ROUNDS - 1,
@@ -326,18 +361,18 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for PoseidonGate<F, D> {
 
         // Second set of full rounds.
         for r in 0..poseidon::HALF_N_FULL_ROUNDS {
-            <F as Poseidon>::constant_layer_recursive(builder, &mut state, round_ctr);
-            for i in 0..SPONGE_WIDTH {
+            <F as Poseidon<WIDTH>>::constant_layer_recursive(builder, &mut state, round_ctr);
+            for i in 0..WIDTH {
                 let sbox_in = vars.local_wires[Self::wire_full_sbox_1(r, i)];
                 constraints.push(builder.sub_extension(state[i], sbox_in));
                 state[i] = sbox_in;
             }
-            <F as Poseidon>::sbox_layer_recursive(builder, &mut state);
-            state = <F as Poseidon>::mds_layer_recursive(builder, &state);
+            <F as Poseidon<WIDTH>>::sbox_layer_recursive(builder, &mut state);
+            state = <F as Poseidon<WIDTH>>::mds_layer_recursive(builder, &state);
             round_ctr += 1;
         }
 
-        for i in 0..SPONGE_WIDTH {
+        for i in 0..WIDTH {
             constraints
                 .push(builder.sub_extension(state[i], vars.local_wires[Self::wire_output(i)]));
         }
@@ -350,7 +385,7 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for PoseidonGate<F, D> {
         gate_index: usize,
         _local_constants: &[F],
     ) -> Vec<Box<dyn WitnessGenerator<F>>> {
-        let gen = PoseidonGenerator::<F, D> {
+        let gen = PoseidonGenerator::<F, D, WIDTH> {
             gate_index,
             _phantom: PhantomData,
         };
@@ -370,23 +405,31 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for PoseidonGate<F, D> {
     }
 
     fn num_constraints(&self) -> usize {
-        SPONGE_WIDTH * poseidon::N_FULL_ROUNDS_TOTAL + poseidon::N_PARTIAL_ROUNDS + SPONGE_WIDTH + 1
+        WIDTH * (poseidon::N_FULL_ROUNDS_TOTAL - 1) + poseidon::N_PARTIAL_ROUNDS + WIDTH + 1 + 4
    }
 }
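// Worked count (editor's sketch): for the WIDTH = 12 Goldilocks instance in the
// tests below (HALF_N_FULL_ROUNDS = 4 and N_PARTIAL_ROUNDS = 22, as the
// `wire_indices` test implies), this is 12 * (8 - 1) + 22 + 12 + 1 + 4 = 123:
// the first full round no longer stores S-box input wires, while the four new
// delta wires each cost one extra constraint.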
} }
#[derive(Debug)] #[derive(Debug)]
struct PoseidonGenerator<F: Extendable<D> + Poseidon, const D: usize> { struct PoseidonGenerator<
F: RichField + Extendable<D> + Poseidon<WIDTH>,
const D: usize,
const WIDTH: usize,
> where
[(); WIDTH - 1]:,
{
gate_index: usize, gate_index: usize,
_phantom: PhantomData<F>, _phantom: PhantomData<F>,
} }
impl<F: RichField + Extendable<D> + Poseidon, const D: usize> SimpleGenerator<F> impl<F: RichField + Extendable<D> + Poseidon<WIDTH>, const D: usize, const WIDTH: usize>
for PoseidonGenerator<F, D> SimpleGenerator<F> for PoseidonGenerator<F, D, WIDTH>
where
[(); WIDTH - 1]:,
{ {
fn dependencies(&self) -> Vec<Target> { fn dependencies(&self) -> Vec<Target> {
(0..SPONGE_WIDTH) (0..WIDTH)
.map(|i| PoseidonGate::<F, D>::wire_input(i)) .map(|i| PoseidonGate::<F, D, WIDTH>::wire_input(i))
.chain(Some(PoseidonGate::<F, D>::WIRE_SWAP)) .chain(Some(PoseidonGate::<F, D, WIDTH>::WIRE_SWAP))
.map(|input| Target::wire(self.gate_index, input)) .map(|input| Target::wire(self.gate_index, input))
.collect() .collect()
} }
@ -397,87 +440,94 @@ impl<F: RichField + Extendable<D> + Poseidon, const D: usize> SimpleGenerator<F>
input, input,
}; };
let mut state = (0..SPONGE_WIDTH) let mut state = (0..WIDTH)
.map(|i| { .map(|i| witness.get_wire(local_wire(PoseidonGate::<F, D, WIDTH>::wire_input(i))))
witness.get_wire(Wire {
gate: self.gate_index,
input: PoseidonGate::<F, D>::wire_input(i),
})
})
.collect::<Vec<_>>(); .collect::<Vec<_>>();
let swap_value = witness.get_wire(Wire { let swap_value = witness.get_wire(local_wire(PoseidonGate::<F, D, WIDTH>::WIRE_SWAP));
gate: self.gate_index,
input: PoseidonGate::<F, D>::WIRE_SWAP,
});
debug_assert!(swap_value == F::ZERO || swap_value == F::ONE); debug_assert!(swap_value == F::ZERO || swap_value == F::ONE);
for i in 0..4 {
let delta_i = swap_value * (state[i + 4] - state[i]);
out_buffer.set_wire(
local_wire(PoseidonGate::<F, D, WIDTH>::wire_delta(i)),
delta_i,
);
}
if swap_value == F::ONE { if swap_value == F::ONE {
for i in 0..4 { for i in 0..4 {
state.swap(i, 4 + i); state.swap(i, 4 + i);
} }
} }
let mut state: [F; SPONGE_WIDTH] = state.try_into().unwrap(); let mut state: [F; WIDTH] = state.try_into().unwrap();
let mut round_ctr = 0; let mut round_ctr = 0;
for r in 0..poseidon::HALF_N_FULL_ROUNDS { for r in 0..poseidon::HALF_N_FULL_ROUNDS {
<F as Poseidon>::constant_layer_field(&mut state, round_ctr); <F as Poseidon<WIDTH>>::constant_layer_field(&mut state, round_ctr);
for i in 0..SPONGE_WIDTH { if r != 0 {
for i in 0..WIDTH {
out_buffer.set_wire( out_buffer.set_wire(
local_wire(PoseidonGate::<F, D>::wire_full_sbox_0(r, i)), local_wire(PoseidonGate::<F, D, WIDTH>::wire_full_sbox_0(r, i)),
state[i], state[i],
); );
} }
<F as Poseidon>::sbox_layer_field(&mut state); }
state = <F as Poseidon>::mds_layer_field(&state); <F as Poseidon<WIDTH>>::sbox_layer_field(&mut state);
state = <F as Poseidon<WIDTH>>::mds_layer_field(&state);
round_ctr += 1; round_ctr += 1;
} }
<F as Poseidon>::partial_first_constant_layer(&mut state); <F as Poseidon<WIDTH>>::partial_first_constant_layer(&mut state);
state = <F as Poseidon>::mds_partial_layer_init(&mut state); state = <F as Poseidon<WIDTH>>::mds_partial_layer_init(&state);
for r in 0..(poseidon::N_PARTIAL_ROUNDS - 1) { for r in 0..(poseidon::N_PARTIAL_ROUNDS - 1) {
out_buffer.set_wire( out_buffer.set_wire(
local_wire(PoseidonGate::<F, D>::wire_partial_sbox(r)), local_wire(PoseidonGate::<F, D, WIDTH>::wire_partial_sbox(r)),
state[0], state[0],
); );
state[0] = <F as Poseidon>::sbox_monomial(state[0]); state[0] = <F as Poseidon<WIDTH>>::sbox_monomial(state[0]);
state[0] += F::from_canonical_u64(<F as Poseidon>::FAST_PARTIAL_ROUND_CONSTANTS[r]); state[0] +=
state = <F as Poseidon>::mds_partial_layer_fast_field(&state, r); F::from_canonical_u64(<F as Poseidon<WIDTH>>::FAST_PARTIAL_ROUND_CONSTANTS[r]);
state = <F as Poseidon<WIDTH>>::mds_partial_layer_fast_field(&state, r);
} }
out_buffer.set_wire( out_buffer.set_wire(
local_wire(PoseidonGate::<F, D>::wire_partial_sbox( local_wire(PoseidonGate::<F, D, WIDTH>::wire_partial_sbox(
poseidon::N_PARTIAL_ROUNDS - 1, poseidon::N_PARTIAL_ROUNDS - 1,
)), )),
state[0], state[0],
); );
state[0] = <F as Poseidon>::sbox_monomial(state[0]); state[0] = <F as Poseidon<WIDTH>>::sbox_monomial(state[0]);
state = state = <F as Poseidon<WIDTH>>::mds_partial_layer_fast_field(
<F as Poseidon>::mds_partial_layer_fast_field(&state, poseidon::N_PARTIAL_ROUNDS - 1); &state,
poseidon::N_PARTIAL_ROUNDS - 1,
);
round_ctr += poseidon::N_PARTIAL_ROUNDS; round_ctr += poseidon::N_PARTIAL_ROUNDS;
for r in 0..poseidon::HALF_N_FULL_ROUNDS { for r in 0..poseidon::HALF_N_FULL_ROUNDS {
<F as Poseidon>::constant_layer_field(&mut state, round_ctr); <F as Poseidon<WIDTH>>::constant_layer_field(&mut state, round_ctr);
for i in 0..SPONGE_WIDTH { for i in 0..WIDTH {
out_buffer.set_wire( out_buffer.set_wire(
local_wire(PoseidonGate::<F, D>::wire_full_sbox_1(r, i)), local_wire(PoseidonGate::<F, D, WIDTH>::wire_full_sbox_1(r, i)),
state[i], state[i],
); );
} }
<F as Poseidon>::sbox_layer_field(&mut state); <F as Poseidon<WIDTH>>::sbox_layer_field(&mut state);
state = <F as Poseidon>::mds_layer_field(&state); state = <F as Poseidon<WIDTH>>::mds_layer_field(&state);
round_ctr += 1; round_ctr += 1;
} }
for i in 0..SPONGE_WIDTH { for i in 0..WIDTH {
out_buffer.set_wire(local_wire(PoseidonGate::<F, D>::wire_output(i)), state[i]); out_buffer.set_wire(
local_wire(PoseidonGate::<F, D, WIDTH>::wire_output(i)),
state[i],
);
} }
} }
} }
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use std::convert::TryInto;
use anyhow::Result; use anyhow::Result;
use crate::field::field_types::Field; use crate::field::field_types::Field;
@ -493,6 +543,29 @@ mod tests {
use crate::plonk::circuit_data::CircuitConfig; use crate::plonk::circuit_data::CircuitConfig;
use crate::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; use crate::plonk::config::{GenericConfig, PoseidonGoldilocksConfig};
#[test]
fn wire_indices() {
type F = GoldilocksField;
const WIDTH: usize = 12;
type Gate = PoseidonGate<F, 4, WIDTH>;
assert_eq!(Gate::wire_input(0), 0);
assert_eq!(Gate::wire_input(11), 11);
assert_eq!(Gate::wire_output(0), 12);
assert_eq!(Gate::wire_output(11), 23);
assert_eq!(Gate::WIRE_SWAP, 24);
assert_eq!(Gate::wire_delta(0), 25);
assert_eq!(Gate::wire_delta(3), 28);
assert_eq!(Gate::wire_full_sbox_0(1, 0), 29);
assert_eq!(Gate::wire_full_sbox_0(3, 0), 53);
assert_eq!(Gate::wire_full_sbox_0(3, 11), 64);
assert_eq!(Gate::wire_partial_sbox(0), 65);
assert_eq!(Gate::wire_partial_sbox(21), 86);
assert_eq!(Gate::wire_full_sbox_1(0, 0), 87);
assert_eq!(Gate::wire_full_sbox_1(3, 0), 123);
assert_eq!(Gate::wire_full_sbox_1(3, 11), 134);
}
#[test] #[test]
fn generated_output() { fn generated_output() {
const D: usize = 2; const D: usize = 2;

src/gates/poseidon_mds.rs

@@ -1,4 +1,3 @@
-use std::convert::TryInto;
 use std::marker::PhantomData;
 use std::ops::Range;
 
@@ -6,9 +5,8 @@ use crate::field::extension_field::algebra::ExtensionAlgebra;
 use crate::field::extension_field::target::{ExtensionAlgebraTarget, ExtensionTarget};
 use crate::field::extension_field::Extendable;
 use crate::field::extension_field::FieldExtension;
-use crate::field::field_types::Field;
+use crate::field::field_types::{Field, RichField};
 use crate::gates::gate::Gate;
-use crate::hash::hashing::SPONGE_WIDTH;
 use crate::hash::poseidon::Poseidon;
 use crate::iop::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator};
 use crate::iop::target::Target;
@@ -17,11 +15,21 @@ use crate::plonk::circuit_builder::CircuitBuilder;
 use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase};
 
 #[derive(Debug)]
-pub struct PoseidonMdsGate<F: Extendable<D> + Poseidon, const D: usize> {
+pub struct PoseidonMdsGate<
+    F: RichField + Extendable<D> + Poseidon<WIDTH>,
+    const D: usize,
+    const WIDTH: usize,
+> where
+    [(); WIDTH - 1]:,
+{
     _phantom: PhantomData<F>,
 }
 
-impl<F: Extendable<D> + Poseidon, const D: usize> PoseidonMdsGate<F, D> {
+impl<F: RichField + Extendable<D> + Poseidon<WIDTH>, const D: usize, const WIDTH: usize>
+    PoseidonMdsGate<F, D, WIDTH>
+where
+    [(); WIDTH - 1]:,
+{
     pub fn new() -> Self {
         PoseidonMdsGate {
             _phantom: PhantomData,
@@ -29,13 +37,13 @@ impl<F: Extendable<D> + Poseidon, const D: usize> PoseidonMdsGate<F, D> {
     }
 
     pub fn wires_input(i: usize) -> Range<usize> {
-        assert!(i < SPONGE_WIDTH);
+        assert!(i < WIDTH);
         i * D..(i + 1) * D
     }
 
     pub fn wires_output(i: usize) -> Range<usize> {
-        assert!(i < SPONGE_WIDTH);
-        (SPONGE_WIDTH + i) * D..(SPONGE_WIDTH + i + 1) * D
+        assert!(i < WIDTH);
+        (WIDTH + i) * D..(WIDTH + i + 1) * D
     }
 
     // Following are methods analogous to ones in `Poseidon`, but for extension algebras.
@@ -43,14 +51,15 @@ impl<F: Extendable<D> + Poseidon, const D: usize> PoseidonMdsGate<F, D> {
     /// Same as `mds_row_shf` for an extension algebra of `F`.
     fn mds_row_shf_algebra(
         r: usize,
-        v: &[ExtensionAlgebra<F::Extension, D>; SPONGE_WIDTH],
+        v: &[ExtensionAlgebra<F::Extension, D>; WIDTH],
     ) -> ExtensionAlgebra<F::Extension, D> {
-        debug_assert!(r < SPONGE_WIDTH);
+        debug_assert!(r < WIDTH);
         let mut res = ExtensionAlgebra::ZERO;
-        for i in 0..SPONGE_WIDTH {
-            let coeff = F::Extension::from_canonical_u64(1 << <F as Poseidon>::MDS_MATRIX_EXPS[i]);
-            res += v[(i + r) % SPONGE_WIDTH].scalar_mul(coeff);
+        for i in 0..WIDTH {
+            let coeff =
+                F::Extension::from_canonical_u64(1 << <F as Poseidon<WIDTH>>::MDS_MATRIX_EXPS[i]);
+            res += v[(i + r) % WIDTH].scalar_mul(coeff);
         }
 
         res
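// Editor's note: the matrix applied here is circulant (note the
// `v[(i + r) % WIDTH]` indexing, so row `r` is the first row rotated by `r`),
// and every entry is a power of two (`1 << MDS_MATRIX_EXPS[i]`), so each row
// product is just a sum of scaled copies of the inputs.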
@@ -60,16 +69,16 @@ impl<F: Extendable<D> + Poseidon, const D: usize> PoseidonMdsGate<F, D> {
     fn mds_row_shf_algebra_recursive(
         builder: &mut CircuitBuilder<F, D>,
         r: usize,
-        v: &[ExtensionAlgebraTarget<D>; SPONGE_WIDTH],
+        v: &[ExtensionAlgebraTarget<D>; WIDTH],
     ) -> ExtensionAlgebraTarget<D> {
-        debug_assert!(r < SPONGE_WIDTH);
+        debug_assert!(r < WIDTH);
         let mut res = builder.zero_ext_algebra();
-        for i in 0..SPONGE_WIDTH {
+        for i in 0..WIDTH {
             let coeff = builder.constant_extension(F::Extension::from_canonical_u64(
-                1 << <F as Poseidon>::MDS_MATRIX_EXPS[i],
+                1 << <F as Poseidon<WIDTH>>::MDS_MATRIX_EXPS[i],
             ));
-            res = builder.scalar_mul_add_ext_algebra(coeff, v[(i + r) % SPONGE_WIDTH], res);
+            res = builder.scalar_mul_add_ext_algebra(coeff, v[(i + r) % WIDTH], res);
         }
 
         res
@@ -77,11 +86,11 @@ impl<F: Extendable<D> + Poseidon, const D: usize> PoseidonMdsGate<F, D> {
 
     /// Same as `mds_layer` for an extension algebra of `F`.
     fn mds_layer_algebra(
-        state: &[ExtensionAlgebra<F::Extension, D>; SPONGE_WIDTH],
-    ) -> [ExtensionAlgebra<F::Extension, D>; SPONGE_WIDTH] {
-        let mut result = [ExtensionAlgebra::ZERO; SPONGE_WIDTH];
-        for r in 0..SPONGE_WIDTH {
+        state: &[ExtensionAlgebra<F::Extension, D>; WIDTH],
+    ) -> [ExtensionAlgebra<F::Extension, D>; WIDTH] {
+        let mut result = [ExtensionAlgebra::ZERO; WIDTH];
+        for r in 0..WIDTH {
             result[r] = Self::mds_row_shf_algebra(r, state);
         }
@@ -91,11 +100,11 @@ impl<F: Extendable<D> + Poseidon, const D: usize> PoseidonMdsGate<F, D> {
     /// Same as `mds_layer_recursive` for an extension algebra of `F`.
     fn mds_layer_algebra_recursive(
         builder: &mut CircuitBuilder<F, D>,
-        state: &[ExtensionAlgebraTarget<D>; SPONGE_WIDTH],
-    ) -> [ExtensionAlgebraTarget<D>; SPONGE_WIDTH] {
-        let mut result = [builder.zero_ext_algebra(); SPONGE_WIDTH];
-        for r in 0..SPONGE_WIDTH {
+        state: &[ExtensionAlgebraTarget<D>; WIDTH],
+    ) -> [ExtensionAlgebraTarget<D>; WIDTH] {
+        let mut result = [builder.zero_ext_algebra(); WIDTH];
+        for r in 0..WIDTH {
             result[r] = Self::mds_row_shf_algebra_recursive(builder, r, state);
         }
@@ -103,13 +112,17 @@ impl<F: Extendable<D> + Poseidon, const D: usize> PoseidonMdsGate<F, D> {
     }
 }
 
-impl<F: Extendable<D> + Poseidon, const D: usize> Gate<F, D> for PoseidonMdsGate<F, D> {
+impl<F: RichField + Extendable<D> + Poseidon<WIDTH>, const D: usize, const WIDTH: usize> Gate<F, D>
+    for PoseidonMdsGate<F, D, WIDTH>
+where
+    [(); WIDTH - 1]:,
+{
     fn id(&self) -> String {
-        format!("{:?}<WIDTH={}>", self, SPONGE_WIDTH)
+        format!("{:?}<WIDTH={}>", self, WIDTH)
     }
 
     fn eval_unfiltered(&self, vars: EvaluationVars<F, D>) -> Vec<F::Extension> {
-        let inputs: [_; SPONGE_WIDTH] = (0..SPONGE_WIDTH)
+        let inputs: [_; WIDTH] = (0..WIDTH)
             .map(|i| vars.get_local_ext_algebra(Self::wires_input(i)))
             .collect::<Vec<_>>()
             .try_into()
@@ -117,7 +130,7 @@ impl<F: Extendable<D> + Poseidon, const D: usize> Gate<F, D> for PoseidonMdsGate
 
         let computed_outputs = Self::mds_layer_algebra(&inputs);
 
-        (0..SPONGE_WIDTH)
+        (0..WIDTH)
             .map(|i| vars.get_local_ext_algebra(Self::wires_output(i)))
             .zip(computed_outputs)
             .flat_map(|(out, computed_out)| (out - computed_out).to_basefield_array())
@@ -125,7 +138,7 @@ impl<F: Extendable<D> + Poseidon, const D: usize> Gate<F, D> for PoseidonMdsGate
     }
 
     fn eval_unfiltered_base(&self, vars: EvaluationVarsBase<F>) -> Vec<F> {
-        let inputs: [_; SPONGE_WIDTH] = (0..SPONGE_WIDTH)
+        let inputs: [_; WIDTH] = (0..WIDTH)
             .map(|i| vars.get_local_ext(Self::wires_input(i)))
             .collect::<Vec<_>>()
             .try_into()
@@ -133,7 +146,7 @@ impl<F: Extendable<D> + Poseidon, const D: usize> Gate<F, D> for PoseidonMdsGate
 
         let computed_outputs = F::mds_layer_field(&inputs);
 
-        (0..SPONGE_WIDTH)
+        (0..WIDTH)
             .map(|i| vars.get_local_ext(Self::wires_output(i)))
             .zip(computed_outputs)
             .flat_map(|(out, computed_out)| (out - computed_out).to_basefield_array())
@@ -145,7 +158,7 @@ impl<F: Extendable<D> + Poseidon, const D: usize> Gate<F, D> for PoseidonMdsGate
         builder: &mut CircuitBuilder<F, D>,
         vars: EvaluationTargets<D>,
     ) -> Vec<ExtensionTarget<D>> {
-        let inputs: [_; SPONGE_WIDTH] = (0..SPONGE_WIDTH)
+        let inputs: [_; WIDTH] = (0..WIDTH)
             .map(|i| vars.get_local_ext_algebra(Self::wires_input(i)))
             .collect::<Vec<_>>()
             .try_into()
@@ -153,7 +166,7 @@ impl<F: Extendable<D> + Poseidon, const D: usize> Gate<F, D> for PoseidonMdsGate
 
         let computed_outputs = Self::mds_layer_algebra_recursive(builder, &inputs);
 
-        (0..SPONGE_WIDTH)
+        (0..WIDTH)
             .map(|i| vars.get_local_ext_algebra(Self::wires_output(i)))
             .zip(computed_outputs)
             .flat_map(|(out, computed_out)| {
@@ -169,12 +182,12 @@ impl<F: Extendable<D> + Poseidon, const D: usize> Gate<F, D> for PoseidonMdsGate
         gate_index: usize,
         _local_constants: &[F],
     ) -> Vec<Box<dyn WitnessGenerator<F>>> {
-        let gen = PoseidonMdsGenerator::<D> { gate_index };
+        let gen = PoseidonMdsGenerator::<D, WIDTH> { gate_index };
         vec![Box::new(gen.adapter())]
     }
 
     fn num_wires(&self) -> usize {
-        2 * D * SPONGE_WIDTH
+        2 * D * WIDTH
     }
 
     fn num_constants(&self) -> usize {
@@ -186,20 +199,30 @@ impl<F: Extendable<D> + Poseidon, const D: usize> Gate<F, D> for PoseidonMdsGate
     }
 
     fn num_constraints(&self) -> usize {
-        SPONGE_WIDTH * D
+        WIDTH * D
     }
 }
 
 #[derive(Clone, Debug)]
-struct PoseidonMdsGenerator<const D: usize> {
+struct PoseidonMdsGenerator<const D: usize, const WIDTH: usize>
+where
+    [(); WIDTH - 1]:,
+{
     gate_index: usize,
 }
 
-impl<F: Extendable<D> + Poseidon, const D: usize> SimpleGenerator<F> for PoseidonMdsGenerator<D> {
+impl<F: RichField + Extendable<D> + Poseidon<WIDTH>, const D: usize, const WIDTH: usize>
+    SimpleGenerator<F> for PoseidonMdsGenerator<D, WIDTH>
+where
+    [(); WIDTH - 1]:,
+{
     fn dependencies(&self) -> Vec<Target> {
-        (0..SPONGE_WIDTH)
+        (0..WIDTH)
             .flat_map(|i| {
-                Target::wires_from_range(self.gate_index, PoseidonMdsGate::<F, D>::wires_input(i))
+                Target::wires_from_range(
+                    self.gate_index,
+                    PoseidonMdsGate::<F, D, WIDTH>::wires_input(i),
+                )
             })
             .collect()
     }
@@ -210,8 +233,8 @@ impl<F: Extendable<D> + Poseidon, const D: usize> SimpleGenerator<F> for Poseido
         let get_local_ext =
             |wire_range| witness.get_extension_target(get_local_get_target(wire_range));
 
-        let inputs: [_; SPONGE_WIDTH] = (0..SPONGE_WIDTH)
-            .map(|i| get_local_ext(PoseidonMdsGate::<F, D>::wires_input(i)))
+        let inputs: [_; WIDTH] = (0..WIDTH)
+            .map(|i| get_local_ext(PoseidonMdsGate::<F, D, WIDTH>::wires_input(i)))
             .collect::<Vec<_>>()
             .try_into()
             .unwrap();
@@ -220,7 +243,7 @@ impl<F: Extendable<D> + Poseidon, const D: usize> SimpleGenerator<F> for Poseido
         for (i, &out) in outputs.iter().enumerate() {
             out_buffer.set_extension_target(
-                get_local_get_target(PoseidonMdsGate::<F, D>::wires_output(i)),
+                get_local_get_target(PoseidonMdsGate::<F, D, WIDTH>::wires_output(i)),
out, out,
); );
} }
@ -232,21 +255,19 @@ mod tests {
use crate::field::goldilocks_field::GoldilocksField; use crate::field::goldilocks_field::GoldilocksField;
use crate::gates::gate_testing::{test_eval_fns, test_low_degree}; use crate::gates::gate_testing::{test_eval_fns, test_low_degree};
use crate::gates::poseidon_mds::PoseidonMdsGate; use crate::gates::poseidon_mds::PoseidonMdsGate;
use crate::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; use crate::hash::hashing::SPONGE_WIDTH;
#[test] #[test]
fn low_degree() { fn low_degree() {
type F = GoldilocksField; type F = GoldilocksField;
let gate = PoseidonMdsGate::<F, 4>::new(); let gate = PoseidonMdsGate::<F, 4, SPONGE_WIDTH>::new();
test_low_degree(gate) test_low_degree(gate)
} }
#[test] #[test]
fn eval_fns() -> anyhow::Result<()> { fn eval_fns() -> anyhow::Result<()> {
const D: usize = 2; type F = GoldilocksField;
type C = PoseidonGoldilocksConfig; let gate = PoseidonMdsGate::<F, 4, SPONGE_WIDTH>::new();
type F = <C as GenericConfig<D>>::F; test_eval_fns(gate)
let gate = PoseidonMdsGate::<F, D>::new();
test_eval_fns::<F, C, _, D>(gate)
} }
} }
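
// Editor's note: the wire and constraint counts above follow directly from the
// layout — each of the WIDTH inputs and WIDTH outputs is an extension element
// occupying D wires, and every basefield coordinate of every output contributes
// one constraint. A minimal standalone sketch of that arithmetic (plain usize,
// not the gate code itself):

fn mds_gate_sizes(d: usize, width: usize) -> (usize, usize) {
    let num_wires = 2 * d * width; // WIDTH inputs + WIDTH outputs, D wires each
    let num_constraints = width * d; // one per basefield coordinate per output
    (num_wires, num_constraints)
}

fn main() {
    // e.g. D = 2 with the usual sponge width 12: 48 wires, 24 constraints.
    assert_eq!(mds_gate_sizes(2, 12), (48, 24));
}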

View File

@ -1,5 +1,7 @@
use std::marker::PhantomData; use std::marker::PhantomData;
use itertools::Itertools;
use crate::field::extension_field::target::ExtensionTarget; use crate::field::extension_field::target::ExtensionTarget;
use crate::field::extension_field::Extendable; use crate::field::extension_field::Extendable;
use crate::field::field_types::Field; use crate::field::field_types::Field;
@ -15,75 +17,64 @@ use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase};
/// A gate for checking that a particular element of a list matches a given value. /// A gate for checking that a particular element of a list matches a given value.
#[derive(Copy, Clone, Debug)] #[derive(Copy, Clone, Debug)]
pub(crate) struct RandomAccessGate<F: Extendable<D>, const D: usize> { pub(crate) struct RandomAccessGate<F: Extendable<D>, const D: usize> {
pub vec_size: usize, pub bits: usize,
pub num_copies: usize, pub num_copies: usize,
_phantom: PhantomData<F>, _phantom: PhantomData<F>,
} }
impl<F: Extendable<D>, const D: usize> RandomAccessGate<F, D> { impl<F: Extendable<D>, const D: usize> RandomAccessGate<F, D> {
pub fn new(num_copies: usize, vec_size: usize) -> Self { fn new(num_copies: usize, bits: usize) -> Self {
Self { Self {
vec_size, bits,
num_copies, num_copies,
_phantom: PhantomData, _phantom: PhantomData,
} }
} }
pub fn new_from_config(config: &CircuitConfig, vec_size: usize) -> Self { pub fn new_from_config(config: &CircuitConfig, bits: usize) -> Self {
let num_copies = Self::max_num_copies(config.num_routed_wires, config.num_wires, vec_size); let vec_size = 1 << bits;
Self::new(num_copies, vec_size) // Need `(2 + vec_size) * num_copies` routed wires
let max_copies = (config.num_routed_wires / (2 + vec_size)).min(
// Need `(2 + vec_size + bits) * num_copies` wires
config.num_wires / (2 + vec_size + bits),
);
Self::new(max_copies, bits)
} }
pub fn max_num_copies(num_routed_wires: usize, num_wires: usize, vec_size: usize) -> usize { fn vec_size(&self) -> usize {
// Need `(2 + vec_size) * num_copies` routed wires 1 << self.bits
(num_routed_wires / (2 + vec_size)).min(
// Need `(2 + 3*vec_size) * num_copies` wires
num_wires / (2 + 3 * vec_size),
)
} }
pub fn wire_access_index(&self, copy: usize) -> usize { pub fn wire_access_index(&self, copy: usize) -> usize {
debug_assert!(copy < self.num_copies); debug_assert!(copy < self.num_copies);
(2 + self.vec_size) * copy (2 + self.vec_size()) * copy
} }
pub fn wire_claimed_element(&self, copy: usize) -> usize { pub fn wire_claimed_element(&self, copy: usize) -> usize {
debug_assert!(copy < self.num_copies); debug_assert!(copy < self.num_copies);
(2 + self.vec_size) * copy + 1 (2 + self.vec_size()) * copy + 1
} }
pub fn wire_list_item(&self, i: usize, copy: usize) -> usize { pub fn wire_list_item(&self, i: usize, copy: usize) -> usize {
debug_assert!(i < self.vec_size); debug_assert!(i < self.vec_size());
debug_assert!(copy < self.num_copies); debug_assert!(copy < self.num_copies);
(2 + self.vec_size) * copy + 2 + i (2 + self.vec_size()) * copy + 2 + i
} }
fn start_of_intermediate_wires(&self) -> usize { fn start_of_intermediate_wires(&self) -> usize {
(2 + self.vec_size) * self.num_copies (2 + self.vec_size()) * self.num_copies
} }
pub(crate) fn num_routed_wires(&self) -> usize { pub(crate) fn num_routed_wires(&self) -> usize {
self.start_of_intermediate_wires() self.start_of_intermediate_wires()
} }
/// An intermediate wire for a dummy variable used to show equality. /// An intermediate wire where the prover gives the (purported) binary decomposition of the
/// The prover sets this to 1/(x-y) if x != y, or to an arbitrary value if /// index.
/// x == y. pub fn wire_bit(&self, i: usize, copy: usize) -> usize {
pub fn wire_equality_dummy_for_index(&self, i: usize, copy: usize) -> usize { debug_assert!(i < self.bits);
debug_assert!(i < self.vec_size);
debug_assert!(copy < self.num_copies); debug_assert!(copy < self.num_copies);
self.start_of_intermediate_wires() + copy * self.vec_size + i self.start_of_intermediate_wires() + copy * self.bits + i
}
/// An intermediate wire for the "index_matches" variable (1 if the current index is the index at
/// which to compare, 0 otherwise).
pub fn wire_index_matches_for_index(&self, i: usize, copy: usize) -> usize {
debug_assert!(i < self.vec_size);
debug_assert!(copy < self.num_copies);
self.start_of_intermediate_wires()
+ self.vec_size * self.num_copies
+ self.vec_size * copy
+ i
} }
} }
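
// Editor's note: the copy count in `new_from_config` is a wire-budget
// calculation — each copy needs 2 + 2^bits routed wires (access index, claimed
// element, and the list) plus `bits` unrouted wires for the index bits. A
// standalone sketch with illustrative wire counts (the numbers in `main` are
// made up for the example):

fn max_num_copies(num_routed_wires: usize, num_wires: usize, bits: usize) -> usize {
    let vec_size = 1 << bits; // list length
    // (2 + vec_size) routed wires per copy, plus `bits` wires for the bits.
    (num_routed_wires / (2 + vec_size)).min(num_wires / (2 + vec_size + bits))
}

fn main() {
    // With, say, 80 routed wires out of 135 total and bits = 3:
    // routed budget 80 / 10 = 8, total budget 135 / 13 = 10, so 8 copies fit.
    assert_eq!(max_num_copies(80, 135, 3), 8);
}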
@ -97,23 +88,38 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for RandomAccessGate<F, D> {
for copy in 0..self.num_copies { for copy in 0..self.num_copies {
let access_index = vars.local_wires[self.wire_access_index(copy)]; let access_index = vars.local_wires[self.wire_access_index(copy)];
let list_items = (0..self.vec_size) let mut list_items = (0..self.vec_size())
.map(|i| vars.local_wires[self.wire_list_item(i, copy)]) .map(|i| vars.local_wires[self.wire_list_item(i, copy)])
.collect::<Vec<_>>(); .collect::<Vec<_>>();
let claimed_element = vars.local_wires[self.wire_claimed_element(copy)]; let claimed_element = vars.local_wires[self.wire_claimed_element(copy)];
let bits = (0..self.bits)
.map(|i| vars.local_wires[self.wire_bit(i, copy)])
.collect::<Vec<_>>();
for i in 0..self.vec_size { // Assert that each bit wire value is indeed boolean.
let cur_index = F::Extension::from_canonical_usize(i); for &b in &bits {
let difference = cur_index - access_index; constraints.push(b * (b - F::Extension::ONE));
let equality_dummy = vars.local_wires[self.wire_equality_dummy_for_index(i, copy)];
let index_matches = vars.local_wires[self.wire_index_matches_for_index(i, copy)];
// The two index equality constraints.
constraints.push(difference * equality_dummy - (F::Extension::ONE - index_matches));
constraints.push(index_matches * difference);
// Value equality constraint.
constraints.push((list_items[i] - claimed_element) * index_matches);
} }
// Assert that the binary decomposition was correct.
let reconstructed_index = bits
.iter()
.rev()
.fold(F::Extension::ZERO, |acc, &b| acc.double() + b);
constraints.push(reconstructed_index - access_index);
// Repeatedly fold the list, selecting the left or right item from each pair based on
// the corresponding bit.
for b in bits {
list_items = list_items
.iter()
.tuples()
.map(|(&x, &y)| x + b * (y - x))
.collect()
}
debug_assert_eq!(list_items.len(), 1);
constraints.push(list_items[0] - claimed_element);
} }
constraints constraints
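
// Editor's note: the selection logic above is easiest to see on plain
// integers. `x + b * (y - x)` selects `x` when b = 0 and `y` when b = 1, so
// folding the list once per little-endian index bit leaves exactly
// list[access_index]. A hypothetical out-of-circuit sketch (u64 instead of
// field elements):

fn select_by_bits(mut items: Vec<u64>, bits_le: &[bool]) -> u64 {
    for &b in bits_le {
        // Keep the left element of each pair if b = 0, the right one if b = 1;
        // in-circuit this is the degree-2 expression x + b * (y - x).
        items = items
            .chunks(2)
            .map(|pair| if b { pair[1] } else { pair[0] })
            .collect();
    }
    assert_eq!(items.len(), 1);
    items[0]
}

fn main() {
    let items: Vec<u64> = (10..18).collect();
    // access_index = 6 = 0b110, i.e. little-endian bits [0, 1, 1].
    let bits_le = [false, true, true];
    // Recombining most-significant-first matches the fold in `eval_unfiltered`.
    let index = bits_le.iter().rev().fold(0usize, |acc, &b| 2 * acc + b as usize);
    assert_eq!(index, 6);
    assert_eq!(select_by_bits(items, &bits_le), 16); // items[6]
}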
@ -124,23 +130,35 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for RandomAccessGate<F, D> {
for copy in 0..self.num_copies { for copy in 0..self.num_copies {
let access_index = vars.local_wires[self.wire_access_index(copy)]; let access_index = vars.local_wires[self.wire_access_index(copy)];
let list_items = (0..self.vec_size) let mut list_items = (0..self.vec_size())
.map(|i| vars.local_wires[self.wire_list_item(i, copy)]) .map(|i| vars.local_wires[self.wire_list_item(i, copy)])
.collect::<Vec<_>>(); .collect::<Vec<_>>();
let claimed_element = vars.local_wires[self.wire_claimed_element(copy)]; let claimed_element = vars.local_wires[self.wire_claimed_element(copy)];
let bits = (0..self.bits)
.map(|i| vars.local_wires[self.wire_bit(i, copy)])
.collect::<Vec<_>>();
for i in 0..self.vec_size { // Assert that each bit wire value is indeed boolean.
let cur_index = F::from_canonical_usize(i); for &b in &bits {
let difference = cur_index - access_index; constraints.push(b * (b - F::ONE));
let equality_dummy = vars.local_wires[self.wire_equality_dummy_for_index(i, copy)];
let index_matches = vars.local_wires[self.wire_index_matches_for_index(i, copy)];
// The two index equality constraints.
constraints.push(difference * equality_dummy - (F::ONE - index_matches));
constraints.push(index_matches * difference);
// Value equality constraint.
constraints.push((list_items[i] - claimed_element) * index_matches);
} }
// Assert that the binary decomposition was correct.
let reconstructed_index = bits.iter().rev().fold(F::ZERO, |acc, &b| acc.double() + b);
constraints.push(reconstructed_index - access_index);
// Repeatedly fold the list, selecting the left or right item from each pair based on
// the corresponding bit.
for b in bits {
list_items = list_items
.iter()
.tuples()
.map(|(&x, &y)| x + b * (y - x))
.collect()
}
debug_assert_eq!(list_items.len(), 1);
constraints.push(list_items[0] - claimed_element);
} }
constraints constraints
@ -151,36 +169,44 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for RandomAccessGate<F, D> {
builder: &mut CircuitBuilder<F, D>, builder: &mut CircuitBuilder<F, D>,
vars: EvaluationTargets<D>, vars: EvaluationTargets<D>,
) -> Vec<ExtensionTarget<D>> { ) -> Vec<ExtensionTarget<D>> {
let zero = builder.zero_extension();
let two = builder.two_extension();
let mut constraints = Vec::with_capacity(self.num_constraints()); let mut constraints = Vec::with_capacity(self.num_constraints());
for copy in 0..self.num_copies { for copy in 0..self.num_copies {
let access_index = vars.local_wires[self.wire_access_index(copy)]; let access_index = vars.local_wires[self.wire_access_index(copy)];
let list_items = (0..self.vec_size) let mut list_items = (0..self.vec_size())
.map(|i| vars.local_wires[self.wire_list_item(i, copy)]) .map(|i| vars.local_wires[self.wire_list_item(i, copy)])
.collect::<Vec<_>>(); .collect::<Vec<_>>();
let claimed_element = vars.local_wires[self.wire_claimed_element(copy)]; let claimed_element = vars.local_wires[self.wire_claimed_element(copy)];
let bits = (0..self.bits)
.map(|i| vars.local_wires[self.wire_bit(i, copy)])
.collect::<Vec<_>>();
for i in 0..self.vec_size { // Assert that each bit wire value is indeed boolean.
let cur_index_ext = F::Extension::from_canonical_usize(i); for &b in &bits {
let cur_index = builder.constant_extension(cur_index_ext); constraints.push(builder.mul_sub_extension(b, b, b));
let difference = builder.sub_extension(cur_index, access_index);
let equality_dummy = vars.local_wires[self.wire_equality_dummy_for_index(i, copy)];
let index_matches = vars.local_wires[self.wire_index_matches_for_index(i, copy)];
let one = builder.one_extension();
let not_index_matches = builder.sub_extension(one, index_matches);
let first_equality_constraint =
builder.mul_sub_extension(difference, equality_dummy, not_index_matches);
constraints.push(first_equality_constraint);
let second_equality_constraint = builder.mul_extension(index_matches, difference);
constraints.push(second_equality_constraint);
// Output constraint.
let diff = builder.sub_extension(list_items[i], claimed_element);
let conditional_diff = builder.mul_extension(index_matches, diff);
constraints.push(conditional_diff);
} }
// Assert that the binary decomposition was correct.
let reconstructed_index = bits
.iter()
.rev()
.fold(zero, |acc, &b| builder.mul_add_extension(acc, two, b));
constraints.push(builder.sub_extension(reconstructed_index, access_index));
// Repeatedly fold the list, selecting the left or right item from each pair based on
// the corresponding bit.
for b in bits {
list_items = list_items
.iter()
.tuples()
.map(|(&x, &y)| builder.select_ext_generalized(b, y, x))
.collect()
}
debug_assert_eq!(list_items.len(), 1);
constraints.push(builder.sub_extension(list_items[0], claimed_element));
} }
constraints constraints
@ -207,7 +233,7 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for RandomAccessGate<F, D> {
} }
fn num_wires(&self) -> usize { fn num_wires(&self) -> usize {
self.wire_index_matches_for_index(self.vec_size - 1, self.num_copies - 1) + 1 self.wire_bit(self.bits - 1, self.num_copies - 1) + 1
} }
fn num_constants(&self) -> usize { fn num_constants(&self) -> usize {
@ -215,11 +241,12 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for RandomAccessGate<F, D> {
} }
fn degree(&self) -> usize { fn degree(&self) -> usize {
2 self.bits + 1
} }
fn num_constraints(&self) -> usize { fn num_constraints(&self) -> usize {
3 * self.num_copies * self.vec_size let constraints_per_copy = self.bits + 2;
self.num_copies * constraints_per_copy
} }
} }
@ -234,10 +261,8 @@ impl<F: Extendable<D>, const D: usize> SimpleGenerator<F> for RandomAccessGenera
fn dependencies(&self) -> Vec<Target> { fn dependencies(&self) -> Vec<Target> {
let local_target = |input| Target::wire(self.gate_index, input); let local_target = |input| Target::wire(self.gate_index, input);
let mut deps = Vec::new(); let mut deps = vec![local_target(self.gate.wire_access_index(self.copy))];
deps.push(local_target(self.gate.wire_access_index(self.copy))); for i in 0..self.gate.vec_size() {
deps.push(local_target(self.gate.wire_claimed_element(self.copy)));
for i in 0..self.gate.vec_size {
deps.push(local_target(self.gate.wire_list_item(i, self.copy))); deps.push(local_target(self.gate.wire_list_item(i, self.copy)));
} }
deps deps
@ -250,11 +275,12 @@ impl<F: Extendable<D>, const D: usize> SimpleGenerator<F> for RandomAccessGenera
}; };
let get_local_wire = |input| witness.get_wire(local_wire(input)); let get_local_wire = |input| witness.get_wire(local_wire(input));
let mut set_local_wire = |input, value| out_buffer.set_wire(local_wire(input), value);
// Compute the new vector and the values for equality_dummy and index_matches let copy = self.copy;
let vec_size = self.gate.vec_size; let vec_size = self.gate.vec_size();
let access_index_f = get_local_wire(self.gate.wire_access_index(self.copy));
let access_index_f = get_local_wire(self.gate.wire_access_index(copy));
let access_index = access_index_f.to_canonical_u64() as usize; let access_index = access_index_f.to_canonical_u64() as usize;
debug_assert!( debug_assert!(
access_index < vec_size, access_index < vec_size,
@ -263,22 +289,14 @@ impl<F: Extendable<D>, const D: usize> SimpleGenerator<F> for RandomAccessGenera
vec_size vec_size
); );
for i in 0..vec_size { set_local_wire(
let equality_dummy_wire = self.gate.wire_claimed_element(copy),
local_wire(self.gate.wire_equality_dummy_for_index(i, self.copy)); get_local_wire(self.gate.wire_list_item(access_index, copy)),
let index_matches_wire =
local_wire(self.gate.wire_index_matches_for_index(i, self.copy));
if i == access_index {
out_buffer.set_wire(equality_dummy_wire, F::ONE);
out_buffer.set_wire(index_matches_wire, F::ONE);
} else {
out_buffer.set_wire(
equality_dummy_wire,
(F::from_canonical_usize(i) - F::from_canonical_usize(access_index)).inverse(),
); );
out_buffer.set_wire(index_matches_wire, F::ZERO);
} for i in 0..self.gate.bits {
let bit = F::from_bool(((access_index >> i) & 1) != 0);
set_local_wire(self.gate.wire_bit(i, copy), bit);
} }
} }
} }
@ -322,6 +340,7 @@ mod tests {
/// Returns the local wires for a random access gate given the vectors, elements to compare, /// Returns the local wires for a random access gate given the vectors, elements to compare,
/// and indices. /// and indices.
fn get_wires( fn get_wires(
bits: usize,
lists: Vec<Vec<F>>, lists: Vec<Vec<F>>,
access_indices: Vec<usize>, access_indices: Vec<usize>,
claimed_elements: Vec<F>, claimed_elements: Vec<F>,
@ -330,8 +349,7 @@ mod tests {
let vec_size = lists[0].len(); let vec_size = lists[0].len();
let mut v = Vec::new(); let mut v = Vec::new();
let mut equality_dummy_vals = Vec::new(); let mut bit_vals = Vec::new();
let mut index_matches_vals = Vec::new();
for copy in 0..num_copies { for copy in 0..num_copies {
let access_index = access_indices[copy]; let access_index = access_indices[copy];
v.push(F::from_canonical_usize(access_index)); v.push(F::from_canonical_usize(access_index));
@ -340,26 +358,17 @@ mod tests {
v.push(lists[copy][j]); v.push(lists[copy][j]);
} }
for i in 0..vec_size { for i in 0..bits {
if i == access_index { bit_vals.push(F::from_bool(((access_index >> i) & 1) != 0));
equality_dummy_vals.push(F::ONE);
index_matches_vals.push(F::ONE);
} else {
equality_dummy_vals.push(
(F::from_canonical_usize(i) - F::from_canonical_usize(access_index))
.inverse(),
);
index_matches_vals.push(F::ZERO);
} }
} }
} v.extend(bit_vals);
v.extend(equality_dummy_vals);
v.extend(index_matches_vals);
v.iter().map(|&x| x.into()).collect::<Vec<_>>() v.iter().map(|&x| x.into()).collect()
} }
let vec_size = 3; let bits = 3;
let vec_size = 1 << bits;
let num_copies = 4; let num_copies = 4;
let lists = (0..num_copies) let lists = (0..num_copies)
.map(|_| F::rand_vec(vec_size)) .map(|_| F::rand_vec(vec_size))
@ -368,7 +377,7 @@ mod tests {
.map(|_| thread_rng().gen_range(0..vec_size)) .map(|_| thread_rng().gen_range(0..vec_size))
.collect::<Vec<_>>(); .collect::<Vec<_>>();
let gate = RandomAccessGate::<F, D> { let gate = RandomAccessGate::<F, D> {
vec_size, bits,
num_copies, num_copies,
_phantom: PhantomData, _phantom: PhantomData,
}; };
@ -380,13 +389,18 @@ mod tests {
.collect(); .collect();
let good_vars = EvaluationVars { let good_vars = EvaluationVars {
local_constants: &[], local_constants: &[],
local_wires: &get_wires(lists.clone(), access_indices.clone(), good_claimed_elements), local_wires: &get_wires(
bits,
lists.clone(),
access_indices.clone(),
good_claimed_elements,
),
public_inputs_hash: &HashOut::rand(), public_inputs_hash: &HashOut::rand(),
}; };
let bad_claimed_elements = F::rand_vec(4); let bad_claimed_elements = F::rand_vec(4);
let bad_vars = EvaluationVars { let bad_vars = EvaluationVars {
local_constants: &[], local_constants: &[],
local_wires: &get_wires(lists, access_indices, bad_claimed_elements), local_wires: &get_wires(bits, lists, access_indices, bad_claimed_elements),
public_inputs_hash: &HashOut::rand(), public_inputs_hash: &HashOut::rand(),
}; };

View File

@ -0,0 +1,222 @@
use std::ops::Range;
use crate::field::extension_field::target::ExtensionTarget;
use crate::field::extension_field::Extendable;
use crate::field::extension_field::FieldExtension;
use crate::field::field_types::RichField;
use crate::gates::gate::Gate;
use crate::iop::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator};
use crate::iop::target::Target;
use crate::iop::witness::{PartitionWitness, Witness};
use crate::plonk::circuit_builder::CircuitBuilder;
use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase};
/// Computes `sum alpha^i c_i` for a vector `c_i` of `num_coeffs` elements of the extension field.
#[derive(Debug, Clone)]
pub struct ReducingExtensionGate<const D: usize> {
pub num_coeffs: usize,
}
impl<const D: usize> ReducingExtensionGate<D> {
pub fn new(num_coeffs: usize) -> Self {
Self { num_coeffs }
}
pub fn max_coeffs_len(num_wires: usize, num_routed_wires: usize) -> usize {
// `3*D` routed wires are used for the output, alpha and old accumulator.
// Need `num_coeffs*D` routed wires for coeffs, and `(num_coeffs-1)*D` wires for accumulators.
((num_routed_wires - 3 * D) / D).min((num_wires - 2 * D) / (D * 2))
}
pub fn wires_output() -> Range<usize> {
0..D
}
pub fn wires_alpha() -> Range<usize> {
D..2 * D
}
pub fn wires_old_acc() -> Range<usize> {
2 * D..3 * D
}
const START_COEFFS: usize = 3 * D;
pub fn wires_coeff(i: usize) -> Range<usize> {
Self::START_COEFFS + i * D..Self::START_COEFFS + (i + 1) * D
}
fn start_accs(&self) -> usize {
Self::START_COEFFS + self.num_coeffs * D
}
fn wires_accs(&self, i: usize) -> Range<usize> {
debug_assert!(i < self.num_coeffs);
if i == self.num_coeffs - 1 {
// The last accumulator is the output.
return Self::wires_output();
}
self.start_accs() + D * i..self.start_accs() + D * (i + 1)
}
}
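
// Editor's note: the accumulator wires tie together the Horner-style
// recurrence acc_{i+1} = acc_i * alpha + c_i, whose final value is the output.
// A toy version of the recurrence over u64 instead of extension-field
// elements, just to show the shape:

fn reduce(old_acc: u64, alpha: u64, coeffs: &[u64]) -> u64 {
    coeffs.iter().fold(old_acc, |acc, &c| acc * alpha + c)
}

fn main() {
    // With old_acc = 0 and alpha = 10, reducing [1, 2, 3] gives
    // ((0 * 10 + 1) * 10 + 2) * 10 + 3 = 123.
    assert_eq!(reduce(0, 10, &[1, 2, 3]), 123);
}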
impl<F: RichField + Extendable<D>, const D: usize> Gate<F, D> for ReducingExtensionGate<D> {
fn id(&self) -> String {
format!("{:?}", self)
}
fn eval_unfiltered(&self, vars: EvaluationVars<F, D>) -> Vec<F::Extension> {
let alpha = vars.get_local_ext_algebra(Self::wires_alpha());
let old_acc = vars.get_local_ext_algebra(Self::wires_old_acc());
let coeffs = (0..self.num_coeffs)
.map(|i| vars.get_local_ext_algebra(Self::wires_coeff(i)))
.collect::<Vec<_>>();
let accs = (0..self.num_coeffs)
.map(|i| vars.get_local_ext_algebra(self.wires_accs(i)))
.collect::<Vec<_>>();
let mut constraints = Vec::with_capacity(<Self as Gate<F, D>>::num_constraints(self));
let mut acc = old_acc;
for i in 0..self.num_coeffs {
constraints.push(acc * alpha + coeffs[i] - accs[i]);
acc = accs[i];
}
constraints
.into_iter()
.flat_map(|alg| alg.to_basefield_array())
.collect()
}
fn eval_unfiltered_base(&self, vars: EvaluationVarsBase<F>) -> Vec<F> {
let alpha = vars.get_local_ext(Self::wires_alpha());
let old_acc = vars.get_local_ext(Self::wires_old_acc());
let coeffs = (0..self.num_coeffs)
.map(|i| vars.get_local_ext(Self::wires_coeff(i)))
.collect::<Vec<_>>();
let accs = (0..self.num_coeffs)
.map(|i| vars.get_local_ext(self.wires_accs(i)))
.collect::<Vec<_>>();
let mut constraints = Vec::with_capacity(<Self as Gate<F, D>>::num_constraints(self));
let mut acc = old_acc;
for i in 0..self.num_coeffs {
constraints.extend((acc * alpha + coeffs[i] - accs[i]).to_basefield_array());
acc = accs[i];
}
constraints
}
fn eval_unfiltered_recursively(
&self,
builder: &mut CircuitBuilder<F, D>,
vars: EvaluationTargets<D>,
) -> Vec<ExtensionTarget<D>> {
let alpha = vars.get_local_ext_algebra(Self::wires_alpha());
let old_acc = vars.get_local_ext_algebra(Self::wires_old_acc());
let coeffs = (0..self.num_coeffs)
.map(|i| vars.get_local_ext_algebra(Self::wires_coeff(i)))
.collect::<Vec<_>>();
let accs = (0..self.num_coeffs)
.map(|i| vars.get_local_ext_algebra(self.wires_accs(i)))
.collect::<Vec<_>>();
let mut constraints = Vec::with_capacity(<Self as Gate<F, D>>::num_constraints(self));
let mut acc = old_acc;
for i in 0..self.num_coeffs {
let coeff = coeffs[i];
let mut tmp = builder.mul_add_ext_algebra(acc, alpha, coeff);
tmp = builder.sub_ext_algebra(tmp, accs[i]);
constraints.push(tmp);
acc = accs[i];
}
constraints
.into_iter()
.flat_map(|alg| alg.to_ext_target_array())
.collect()
}
fn generators(
&self,
gate_index: usize,
_local_constants: &[F],
) -> Vec<Box<dyn WitnessGenerator<F>>> {
vec![Box::new(
ReducingGenerator {
gate_index,
gate: self.clone(),
}
.adapter(),
)]
}
fn num_wires(&self) -> usize {
2 * D + 2 * D * self.num_coeffs
}
fn num_constants(&self) -> usize {
0
}
fn degree(&self) -> usize {
2
}
fn num_constraints(&self) -> usize {
D * self.num_coeffs
}
}
#[derive(Debug)]
struct ReducingGenerator<const D: usize> {
gate_index: usize,
gate: ReducingExtensionGate<D>,
}
impl<F: Extendable<D>, const D: usize> SimpleGenerator<F> for ReducingGenerator<D> {
fn dependencies(&self) -> Vec<Target> {
ReducingExtensionGate::<D>::wires_alpha()
.chain(ReducingExtensionGate::<D>::wires_old_acc())
.chain((0..self.gate.num_coeffs).flat_map(ReducingExtensionGate::<D>::wires_coeff))
.map(|i| Target::wire(self.gate_index, i))
.collect()
}
fn run_once(&self, witness: &PartitionWitness<F>, out_buffer: &mut GeneratedValues<F>) {
let local_extension = |range: Range<usize>| -> F::Extension {
let t = ExtensionTarget::from_range(self.gate_index, range);
witness.get_extension_target(t)
};
let alpha = local_extension(ReducingExtensionGate::<D>::wires_alpha());
let old_acc = local_extension(ReducingExtensionGate::<D>::wires_old_acc());
let coeffs = (0..self.gate.num_coeffs)
.map(|i| local_extension(ReducingExtensionGate::<D>::wires_coeff(i)))
.collect::<Vec<_>>();
let accs = (0..self.gate.num_coeffs)
.map(|i| ExtensionTarget::from_range(self.gate_index, self.gate.wires_accs(i)))
.collect::<Vec<_>>();
let mut acc = old_acc;
for i in 0..self.gate.num_coeffs {
let computed_acc = acc * alpha + coeffs[i];
out_buffer.set_extension_target(accs[i], computed_acc);
acc = computed_acc;
}
}
}
#[cfg(test)]
mod tests {
use anyhow::Result;
use crate::field::goldilocks_field::GoldilocksField;
use crate::gates::gate_testing::{test_eval_fns, test_low_degree};
use crate::gates::reducing_extension::ReducingExtensionGate;
#[test]
fn low_degree() {
test_low_degree::<GoldilocksField, _, 4>(ReducingExtensionGate::new(22));
}
#[test]
fn eval_fns() -> Result<()> {
test_eval_fns::<GoldilocksField, _, 4>(ReducingExtensionGate::new(22))
}
}

View File

@ -0,0 +1,423 @@
use std::marker::PhantomData;
use crate::field::extension_field::target::ExtensionTarget;
use crate::field::extension_field::Extendable;
use crate::field::field_types::{Field, RichField};
use crate::gates::gate::Gate;
use crate::iop::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator};
use crate::iop::target::Target;
use crate::iop::wire::Wire;
use crate::iop::witness::{PartitionWitness, Witness};
use crate::plonk::circuit_builder::CircuitBuilder;
use crate::plonk::circuit_data::CircuitConfig;
use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase};
/// A gate to perform a subtraction on 32-bit limbs: given `x`, `y`, and `borrow`, it returns
/// the result `x - y - borrow` and, if this underflows, a new `borrow`. Inputs are not range-checked.
#[derive(Copy, Clone, Debug)]
pub struct U32SubtractionGate<F: RichField + Extendable<D>, const D: usize> {
pub num_ops: usize,
_phantom: PhantomData<F>,
}
impl<F: RichField + Extendable<D>, const D: usize> U32SubtractionGate<F, D> {
pub fn new_from_config(config: &CircuitConfig) -> Self {
Self {
num_ops: Self::num_ops(config),
_phantom: PhantomData,
}
}
pub(crate) fn num_ops(config: &CircuitConfig) -> usize {
let wires_per_op = 5 + Self::num_limbs();
let routed_wires_per_op = 5;
(config.num_wires / wires_per_op).min(config.num_routed_wires / routed_wires_per_op)
}
pub fn wire_ith_input_x(&self, i: usize) -> usize {
debug_assert!(i < self.num_ops);
5 * i
}
pub fn wire_ith_input_y(&self, i: usize) -> usize {
debug_assert!(i < self.num_ops);
5 * i + 1
}
pub fn wire_ith_input_borrow(&self, i: usize) -> usize {
debug_assert!(i < self.num_ops);
5 * i + 2
}
pub fn wire_ith_output_result(&self, i: usize) -> usize {
debug_assert!(i < self.num_ops);
5 * i + 3
}
pub fn wire_ith_output_borrow(&self, i: usize) -> usize {
debug_assert!(i < self.num_ops);
5 * i + 4
}
pub fn limb_bits() -> usize {
2
}
// We have limbs for the 32 bits of `output_result`.
pub fn num_limbs() -> usize {
32 / Self::limb_bits()
}
pub fn wire_ith_output_jth_limb(&self, i: usize, j: usize) -> usize {
debug_assert!(i < self.num_ops);
debug_assert!(j < Self::num_limbs());
5 * self.num_ops + Self::num_limbs() * i + j
}
}
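
// Editor's note: the relation the gate enforces is
// result = x - y - borrow + 2^32 * output_borrow, with `result` range-checked
// through sixteen 2-bit limbs and `output_borrow` constrained boolean. A u64
// sketch of that arithmetic (a hypothetical helper, not the in-circuit field
// computation):

fn u32_sub_with_borrow(x: u64, y: u64, borrow: u64) -> (u64, u64) {
    let diff = x.wrapping_sub(y).wrapping_sub(borrow);
    let output_borrow = diff >> 63; // 1 iff the subtraction underflowed
    let result = diff.wrapping_add(output_borrow << 32) & 0xFFFF_FFFF;
    (result, output_borrow)
}

fn main() {
    // 5 - 7 - 0 underflows: result = 5 - 7 + 2^32 = 2^32 - 2, borrow out = 1.
    let (result, b) = u32_sub_with_borrow(5, 7, 0);
    assert_eq!((result, b), ((1u64 << 32) - 2, 1));

    // The 32-bit range check decomposes `result` into 2-bit (base-4) limbs.
    let limbs: Vec<u64> = (0..16).map(|j| (result >> (2 * j)) & 0b11).collect();
    let recombined = limbs.iter().rev().fold(0u64, |acc, &l| 4 * acc + l);
    assert_eq!(recombined, result);
}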
impl<F: RichField + Extendable<D>, const D: usize> Gate<F, D> for U32SubtractionGate<F, D> {
fn id(&self) -> String {
format!("{:?}", self)
}
fn eval_unfiltered(&self, vars: EvaluationVars<F, D>) -> Vec<F::Extension> {
let mut constraints = Vec::with_capacity(self.num_constraints());
for i in 0..self.num_ops {
let input_x = vars.local_wires[self.wire_ith_input_x(i)];
let input_y = vars.local_wires[self.wire_ith_input_y(i)];
let input_borrow = vars.local_wires[self.wire_ith_input_borrow(i)];
let result_initial = input_x - input_y - input_borrow;
let base = F::Extension::from_canonical_u64(1 << 32u64);
let output_result = vars.local_wires[self.wire_ith_output_result(i)];
let output_borrow = vars.local_wires[self.wire_ith_output_borrow(i)];
constraints.push(output_result - (result_initial + base * output_borrow));
// Range-check output_result to be at most 32 bits.
let mut combined_limbs = F::Extension::ZERO;
let limb_base = F::Extension::from_canonical_u64(1u64 << Self::limb_bits());
for j in (0..Self::num_limbs()).rev() {
let this_limb = vars.local_wires[self.wire_ith_output_jth_limb(i, j)];
let max_limb = 1 << Self::limb_bits();
let product = (0..max_limb)
.map(|x| this_limb - F::Extension::from_canonical_usize(x))
.product();
constraints.push(product);
combined_limbs = limb_base * combined_limbs + this_limb;
}
constraints.push(combined_limbs - output_result);
// Range-check output_borrow to be one bit.
constraints.push(output_borrow * (F::Extension::ONE - output_borrow));
}
constraints
}
fn eval_unfiltered_base(&self, vars: EvaluationVarsBase<F>) -> Vec<F> {
let mut constraints = Vec::with_capacity(self.num_constraints());
for i in 0..self.num_ops {
let input_x = vars.local_wires[self.wire_ith_input_x(i)];
let input_y = vars.local_wires[self.wire_ith_input_y(i)];
let input_borrow = vars.local_wires[self.wire_ith_input_borrow(i)];
let result_initial = input_x - input_y - input_borrow;
let base = F::from_canonical_u64(1 << 32u64);
let output_result = vars.local_wires[self.wire_ith_output_result(i)];
let output_borrow = vars.local_wires[self.wire_ith_output_borrow(i)];
constraints.push(output_result - (result_initial + base * output_borrow));
// Range-check output_result to be at most 32 bits.
let mut combined_limbs = F::ZERO;
let limb_base = F::from_canonical_u64(1u64 << Self::limb_bits());
for j in (0..Self::num_limbs()).rev() {
let this_limb = vars.local_wires[self.wire_ith_output_jth_limb(i, j)];
let max_limb = 1 << Self::limb_bits();
let product = (0..max_limb)
.map(|x| this_limb - F::from_canonical_usize(x))
.product();
constraints.push(product);
combined_limbs = limb_base * combined_limbs + this_limb;
}
constraints.push(combined_limbs - output_result);
// Range-check output_borrow to be one bit.
constraints.push(output_borrow * (F::ONE - output_borrow));
}
constraints
}
fn eval_unfiltered_recursively(
&self,
builder: &mut CircuitBuilder<F, D>,
vars: EvaluationTargets<D>,
) -> Vec<ExtensionTarget<D>> {
let mut constraints = Vec::with_capacity(self.num_constraints());
for i in 0..self.num_ops {
let input_x = vars.local_wires[self.wire_ith_input_x(i)];
let input_y = vars.local_wires[self.wire_ith_input_y(i)];
let input_borrow = vars.local_wires[self.wire_ith_input_borrow(i)];
let diff = builder.sub_extension(input_x, input_y);
let result_initial = builder.sub_extension(diff, input_borrow);
let base = builder.constant_extension(F::Extension::from_canonical_u64(1 << 32u64));
let output_result = vars.local_wires[self.wire_ith_output_result(i)];
let output_borrow = vars.local_wires[self.wire_ith_output_borrow(i)];
let computed_output = builder.mul_add_extension(base, output_borrow, result_initial);
constraints.push(builder.sub_extension(output_result, computed_output));
// Range-check output_result to be at most 32 bits.
let mut combined_limbs = builder.zero_extension();
let limb_base = builder
.constant_extension(F::Extension::from_canonical_u64(1u64 << Self::limb_bits()));
for j in (0..Self::num_limbs()).rev() {
let this_limb = vars.local_wires[self.wire_ith_output_jth_limb(i, j)];
let max_limb = 1 << Self::limb_bits();
let mut product = builder.one_extension();
for x in 0..max_limb {
let x_target =
builder.constant_extension(F::Extension::from_canonical_usize(x));
let diff = builder.sub_extension(this_limb, x_target);
product = builder.mul_extension(product, diff);
}
constraints.push(product);
combined_limbs = builder.mul_add_extension(limb_base, combined_limbs, this_limb);
}
constraints.push(builder.sub_extension(combined_limbs, output_result));
// Range-check output_borrow to be one bit.
let one = builder.one_extension();
let not_borrow = builder.sub_extension(one, output_borrow);
constraints.push(builder.mul_extension(output_borrow, not_borrow));
}
constraints
}
fn generators(
&self,
gate_index: usize,
_local_constants: &[F],
) -> Vec<Box<dyn WitnessGenerator<F>>> {
(0..self.num_ops)
.map(|i| {
let g: Box<dyn WitnessGenerator<F>> = Box::new(
U32SubtractionGenerator {
gate: *self,
gate_index,
i,
_phantom: PhantomData,
}
.adapter(),
);
g
})
.collect()
}
fn num_wires(&self) -> usize {
self.num_ops * (5 + Self::num_limbs())
}
fn num_constants(&self) -> usize {
0
}
fn degree(&self) -> usize {
1 << Self::limb_bits()
}
fn num_constraints(&self) -> usize {
self.num_ops * (3 + Self::num_limbs())
}
}
#[derive(Clone, Debug)]
struct U32SubtractionGenerator<F: RichField + Extendable<D>, const D: usize> {
gate: U32SubtractionGate<F, D>,
gate_index: usize,
i: usize,
_phantom: PhantomData<F>,
}
impl<F: RichField + Extendable<D>, const D: usize> SimpleGenerator<F>
for U32SubtractionGenerator<F, D>
{
fn dependencies(&self) -> Vec<Target> {
let local_target = |input| Target::wire(self.gate_index, input);
vec![
local_target(self.gate.wire_ith_input_x(self.i)),
local_target(self.gate.wire_ith_input_y(self.i)),
local_target(self.gate.wire_ith_input_borrow(self.i)),
]
}
fn run_once(&self, witness: &PartitionWitness<F>, out_buffer: &mut GeneratedValues<F>) {
let local_wire = |input| Wire {
gate: self.gate_index,
input,
};
let get_local_wire = |input| witness.get_wire(local_wire(input));
let input_x = get_local_wire(self.gate.wire_ith_input_x(self.i));
let input_y = get_local_wire(self.gate.wire_ith_input_y(self.i));
let input_borrow = get_local_wire(self.gate.wire_ith_input_borrow(self.i));
let result_initial = input_x - input_y - input_borrow;
let result_initial_u64 = result_initial.to_canonical_u64();
let output_borrow = if result_initial_u64 > 1 << 32u64 {
F::ONE
} else {
F::ZERO
};
let base = F::from_canonical_u64(1 << 32u64);
let output_result = result_initial + base * output_borrow;
let output_result_wire = local_wire(self.gate.wire_ith_output_result(self.i));
let output_borrow_wire = local_wire(self.gate.wire_ith_output_borrow(self.i));
out_buffer.set_wire(output_result_wire, output_result);
out_buffer.set_wire(output_borrow_wire, output_borrow);
let output_result_u64 = output_result.to_canonical_u64();
let num_limbs = U32SubtractionGate::<F, D>::num_limbs();
let limb_base = 1 << U32SubtractionGate::<F, D>::limb_bits();
let output_limbs: Vec<_> = (0..num_limbs)
.scan(output_result_u64, |acc, _| {
let tmp = *acc % limb_base;
*acc /= limb_base;
Some(F::from_canonical_u64(tmp))
})
.collect();
for j in 0..num_limbs {
let wire = local_wire(self.gate.wire_ith_output_jth_limb(self.i, j));
out_buffer.set_wire(wire, output_limbs[j]);
}
}
}
#[cfg(test)]
mod tests {
use std::marker::PhantomData;
use anyhow::Result;
use rand::Rng;
use crate::field::extension_field::quartic::QuarticExtension;
use crate::field::field_types::{Field, PrimeField};
use crate::field::goldilocks_field::GoldilocksField;
use crate::gates::gate::Gate;
use crate::gates::gate_testing::{test_eval_fns, test_low_degree};
use crate::gates::subtraction_u32::U32SubtractionGate;
use crate::hash::hash_types::HashOut;
use crate::plonk::vars::EvaluationVars;
#[test]
fn low_degree() {
test_low_degree::<GoldilocksField, _, 4>(U32SubtractionGate::<GoldilocksField, 4> {
num_ops: 3,
_phantom: PhantomData,
})
}
#[test]
fn eval_fns() -> Result<()> {
test_eval_fns::<GoldilocksField, _, 4>(U32SubtractionGate::<GoldilocksField, 4> {
num_ops: 3,
_phantom: PhantomData,
})
}
#[test]
fn test_gate_constraint() {
type F = GoldilocksField;
type FF = QuarticExtension<GoldilocksField>;
const D: usize = 4;
const NUM_U32_SUBTRACTION_OPS: usize = 3;
fn get_wires(inputs_x: Vec<u64>, inputs_y: Vec<u64>, borrows: Vec<u64>) -> Vec<FF> {
let mut v0 = Vec::new();
let mut v1 = Vec::new();
let limb_bits = U32SubtractionGate::<F, D>::limb_bits();
let num_limbs = U32SubtractionGate::<F, D>::num_limbs();
let limb_base = 1 << limb_bits;
for c in 0..NUM_U32_SUBTRACTION_OPS {
let input_x = F::from_canonical_u64(inputs_x[c]);
let input_y = F::from_canonical_u64(inputs_y[c]);
let input_borrow = F::from_canonical_u64(borrows[c]);
let result_initial = input_x - input_y - input_borrow;
let result_initial_u64 = result_initial.to_canonical_u64();
let output_borrow = if result_initial_u64 > 1 << 32u64 {
F::ONE
} else {
F::ZERO
};
let base = F::from_canonical_u64(1 << 32u64);
let output_result = result_initial + base * output_borrow;
let output_result_u64 = output_result.to_canonical_u64();
let mut output_limbs: Vec<_> = (0..num_limbs)
.scan(output_result_u64, |acc, _| {
let tmp = *acc % limb_base;
*acc /= limb_base;
Some(F::from_canonical_u64(tmp))
})
.collect();
v0.push(input_x);
v0.push(input_y);
v0.push(input_borrow);
v0.push(output_result);
v0.push(output_borrow);
v1.append(&mut output_limbs);
}
v0.iter()
.chain(v1.iter())
.map(|&x| x.into())
.collect::<Vec<_>>()
}
let mut rng = rand::thread_rng();
let inputs_x = (0..NUM_U32_SUBTRACTION_OPS)
.map(|_| rng.gen::<u32>() as u64)
.collect();
let inputs_y = (0..NUM_U32_SUBTRACTION_OPS)
.map(|_| rng.gen::<u32>() as u64)
.collect();
let borrows = (0..NUM_U32_SUBTRACTION_OPS)
.map(|_| (rng.gen::<u32>() % 2) as u64)
.collect();
let gate = U32SubtractionGate::<F, D> {
num_ops: NUM_U32_SUBTRACTION_OPS,
_phantom: PhantomData,
};
let vars = EvaluationVars {
local_constants: &[],
local_wires: &get_wires(inputs_x, inputs_y, borrows),
public_inputs_hash: &HashOut::rand(),
};
assert!(
gate.eval_unfiltered(vars).iter().all(|x| x.is_zero()),
"Gate constraints are not satisfied."
);
}
}

View File

@ -1,5 +1,7 @@
#![allow(clippy::assertions_on_constants)]
use std::arch::aarch64::*; use std::arch::aarch64::*;
use std::convert::TryInto; use std::arch::asm;
use static_assertions::const_assert; use static_assertions::const_assert;
use unroll::unroll_for_loops; use unroll::unroll_for_loops;
@ -172,9 +174,7 @@ unsafe fn multiply(x: u64, y: u64) -> u64 {
let xy_hi_lo_mul_epsilon = mul_epsilon(xy_hi); let xy_hi_lo_mul_epsilon = mul_epsilon(xy_hi);
// add_with_wraparound is safe, as xy_hi_lo_mul_epsilon <= 0xfffffffe00000001 <= ORDER. // add_with_wraparound is safe, as xy_hi_lo_mul_epsilon <= 0xfffffffe00000001 <= ORDER.
let res1 = add_with_wraparound(res0, xy_hi_lo_mul_epsilon); add_with_wraparound(res0, xy_hi_lo_mul_epsilon)
res1
} }
// ==================================== STANDALONE CONST LAYER ===================================== // ==================================== STANDALONE CONST LAYER =====================================
@ -267,9 +267,7 @@ unsafe fn mds_reduce(
// Multiply by EPSILON and accumulate. // Multiply by EPSILON and accumulate.
let res_unadj = vmlal_laneq_u32::<0>(res_lo, res_hi_hi, mds_consts0); let res_unadj = vmlal_laneq_u32::<0>(res_lo, res_hi_hi, mds_consts0);
let res_adj = vcgtq_u64(res_lo, res_unadj); let res_adj = vcgtq_u64(res_lo, res_unadj);
let res = vsraq_n_u64::<32>(res_unadj, res_adj); vsraq_n_u64::<32>(res_unadj, res_adj)
res
} }
#[inline(always)] #[inline(always)]
@ -969,8 +967,7 @@ unsafe fn partial_round(
#[inline(always)] #[inline(always)]
unsafe fn full_round(state: [u64; 12], round_constants: &[u64; WIDTH]) -> [u64; 12] { unsafe fn full_round(state: [u64; 12], round_constants: &[u64; WIDTH]) -> [u64; 12] {
let state = sbox_layer_full(state); let state = sbox_layer_full(state);
let state = mds_const_layers_full(state, round_constants); mds_const_layers_full(state, round_constants)
state
} }
#[inline] #[inline]

View File

@ -1,5 +1,4 @@
use core::arch::x86_64::*; use core::arch::x86_64::*;
use std::convert::TryInto;
use std::mem::size_of; use std::mem::size_of;
use static_assertions::const_assert; use static_assertions::const_assert;

View File

@ -1,5 +1,3 @@
use std::convert::TryInto;
use rand::Rng; use rand::Rng;
use serde::{Deserialize, Deserializer, Serialize, Serializer}; use serde::{Deserialize, Deserializer, Serialize, Serializer};

View File

@ -1,7 +1,5 @@
//! Concrete instantiation of a hash function. //! Concrete instantiation of a hash function.
use std::convert::TryInto;
use crate::field::extension_field::Extendable; use crate::field::extension_field::Extendable;
use crate::field::field_types::RichField; use crate::field::field_types::RichField;
use crate::hash::hash_types::{HashOut, HashOutTarget}; use crate::hash::hash_types::{HashOut, HashOutTarget};
@ -131,10 +129,8 @@ pub fn hash_n_to_m<F: RichField, P: PlonkyPermutation<F>>(
// Absorb all input chunks. // Absorb all input chunks.
for input_chunk in inputs.chunks(SPONGE_RATE) { for input_chunk in inputs.chunks(SPONGE_RATE) {
for i in 0..input_chunk.len() { state[..input_chunk.len()].copy_from_slice(input_chunk);
state[i] = input_chunk[i]; state = permute(state);
}
state = P::permute(state);
} }
// Squeeze until we have the desired number of outputs. // Squeeze until we have the desired number of outputs.
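
// Editor's note: the overwrite-mode sponge that `hash_n_to_m` implements
// looks roughly like the sketch below, with a stand-in permutation and u64
// words instead of field elements (`toy_permute` and `hash_n_to_m_sketch` are
// illustrative names only):

const RATE: usize = 8; // SPONGE_RATE in the real code
const WIDTH: usize = 12; // SPONGE_WIDTH in the real code

fn toy_permute(mut state: [u64; WIDTH]) -> [u64; WIDTH] {
    state.rotate_left(1); // stand-in for the Poseidon permutation
    state
}

fn hash_n_to_m_sketch(inputs: &[u64], num_outputs: usize) -> Vec<u64> {
    let mut state = [0u64; WIDTH];
    // Absorb: each chunk overwrites the front of the state, then permute.
    for chunk in inputs.chunks(RATE) {
        state[..chunk.len()].copy_from_slice(chunk);
        state = toy_permute(state);
    }
    // Squeeze: read off the rate portion, permuting between blocks.
    let mut outputs = Vec::new();
    loop {
        for &word in &state[..RATE] {
            outputs.push(word);
            if outputs.len() == num_outputs {
                return outputs;
            }
        }
        state = toy_permute(state);
    }
}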

View File

@ -1,5 +1,3 @@
use std::convert::TryInto;
use anyhow::{ensure, Result}; use anyhow::{ensure, Result};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
@ -55,6 +53,7 @@ pub(crate) fn verify_merkle_proof<F: RichField, H: Hasher<F>>(
impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> { impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
/// Verifies that the given leaf data is present at the given index in the Merkle tree with the /// Verifies that the given leaf data is present at the given index in the Merkle tree with the
/// given cap. The index is given by its little-endian bits. /// given cap. The index is given by its little-endian bits.
#[cfg(test)]
pub(crate) fn verify_merkle_proof<H: AlgebraicHasher<F>>( pub(crate) fn verify_merkle_proof<H: AlgebraicHasher<F>>(
&mut self, &mut self,
leaf_data: Vec<Target>, leaf_data: Vec<Target>,
@ -94,7 +93,7 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
proof: &MerkleProofTarget, proof: &MerkleProofTarget,
) { ) {
let zero = self.zero(); let zero = self.zero();
let mut state: HashOutTarget = self.hash_or_noop::<H>(leaf_data); let mut state: HashOutTarget = self.hash_or_noop(leaf_data);
for (&bit, &sibling) in leaf_index_bits.iter().zip(&proof.siblings) { for (&bit, &sibling) in leaf_index_bits.iter().zip(&proof.siblings) {
let mut perm_inputs = [zero; SPONGE_WIDTH]; let mut perm_inputs = [zero; SPONGE_WIDTH];
@ -116,7 +115,7 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
} }
} }
pub fn assert_hashes_equal(&mut self, x: HashOutTarget, y: HashOutTarget) { pub fn connect_hashes(&mut self, x: HashOutTarget, y: HashOutTarget) {
for i in 0..4 { for i in 0..4 {
self.connect(x.elements[i], y.elements[i]); self.connect(x.elements[i], y.elements[i]);
} }
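
// Editor's note: outside the circuit, the walk the verification gadget above
// performs can be pictured like this sketch, where a stand-in two-to-one hash
// takes the place of Poseidon and the little-endian index bits pick the child
// ordering at each level:

fn hash2(left: u64, right: u64) -> u64 {
    left.wrapping_mul(31).wrapping_add(right) // stand-in compression function
}

fn verify_path(leaf: u64, index_bits_le: &[bool], siblings: &[u64], root: u64) -> bool {
    let mut state = hash2(leaf, 0); // stand-in for hash_or_noop on the leaf data
    for (&bit, &sibling) in index_bits_le.iter().zip(siblings) {
        // bit = 0: running hash is the left child; bit = 1: it is the right child.
        state = if bit {
            hash2(sibling, state)
        } else {
            hash2(state, sibling)
        };
    }
    state == root
}

fn main() {
    // Build a root for leaf index 2 (bits [0, 1]) and check the path verifies.
    let (leaf, s0, s1) = (42, 7, 9);
    let root = hash2(s1, hash2(hash2(leaf, 0), s0));
    assert!(verify_path(leaf, &[false, true], &[s0, s1], root));
}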

View File

@ -1,8 +1,6 @@
//! Implementation of the Poseidon hash function, as described in //! Implementation of the Poseidon hash function, as described in
//! https://eprint.iacr.org/2019/458.pdf //! https://eprint.iacr.org/2019/458.pdf
use std::convert::TryInto;
use unroll::unroll_for_loops; use unroll::unroll_for_loops;
use crate::field::extension_field::target::ExtensionTarget; use crate::field::extension_field::target::ExtensionTarget;
@ -452,9 +450,10 @@ pub trait Poseidon: PrimeField {
s0, s0,
); );
for i in 1..WIDTH { for i in 1..WIDTH {
let t = <Self as Poseidon>::FAST_PARTIAL_ROUND_W_HATS[r][i - 1]; let t = <Self as Poseidon<WIDTH>>::FAST_PARTIAL_ROUND_W_HATS[r][i - 1];
let t = Self::from_canonical_u64(t); let t = Self::Extension::from_canonical_u64(t);
d = builder.mul_const_add_extension(t, state[i], d); let t = builder.constant_extension(t);
d = builder.mul_add_extension(t, state[i], d);
} }
let mut result = [builder.zero_extension(); WIDTH]; let mut result = [builder.zero_extension(); WIDTH];
@ -624,6 +623,7 @@ pub trait Poseidon: PrimeField {
} }
} }
#[cfg(test)]
pub(crate) mod test_helpers { pub(crate) mod test_helpers {
use crate::field::field_types::Field; use crate::field::field_types::Field;
use crate::hash::hashing::SPONGE_WIDTH; use crate::hash::hashing::SPONGE_WIDTH;

View File

@ -1,7 +1,7 @@
use crate::field::extension_field::target::ExtensionTarget;
use std::convert::TryInto; use std::convert::TryInto;
use std::marker::PhantomData; use std::marker::PhantomData;
use crate::field::extension_field::target::ExtensionTarget;
use crate::field::extension_field::{Extendable, FieldExtension}; use crate::field::extension_field::{Extendable, FieldExtension};
use crate::field::field_types::RichField; use crate::field::field_types::RichField;
use crate::hash::hash_types::{HashOut, HashOutTarget, MerkleCapTarget}; use crate::hash::hash_types::{HashOut, HashOutTarget, MerkleCapTarget};

View File

@ -1,9 +1,15 @@
use std::fmt::Debug; use std::fmt::Debug;
use std::marker::PhantomData; use std::marker::PhantomData;
use num::BigUint;
use crate::field::extension_field::target::ExtensionTarget; use crate::field::extension_field::target::ExtensionTarget;
use crate::field::extension_field::{Extendable, FieldExtension}; use crate::field::extension_field::{Extendable, FieldExtension};
use crate::field::field_types::Field; use crate::field::field_types::Field;
use crate::field::field_types::{Field, RichField};
use crate::gadgets::arithmetic_u32::U32Target;
use crate::gadgets::biguint::BigUintTarget;
use crate::gadgets::nonnative::NonNativeTarget;
use crate::hash::hash_types::{HashOut, HashOutTarget}; use crate::hash::hash_types::{HashOut, HashOutTarget};
use crate::iop::target::Target; use crate::iop::target::Target;
use crate::iop::wire::Wire; use crate::iop::wire::Wire;
@ -89,7 +95,7 @@ pub(crate) fn generate_partial_witness<
assert_eq!( assert_eq!(
remaining_generators, 0, remaining_generators, 0,
"{} generators weren't run", "{} generators weren't run",
remaining_generators remaining_generators,
); );
witness witness
@ -156,6 +162,24 @@ impl<F: Field> GeneratedValues<F> {
self.target_values.push((target, value)) self.target_values.push((target, value))
} }
fn set_u32_target(&mut self, target: U32Target, value: u32) {
self.set_target(target.0, F::from_canonical_u32(value))
}
pub fn set_biguint_target(&mut self, target: BigUintTarget, value: BigUint) {
let mut limbs = value.to_u32_digits();
assert!(target.num_limbs() >= limbs.len());
limbs.resize(target.num_limbs(), 0);
for i in 0..target.num_limbs() {
self.set_u32_target(target.get_limb(i), limbs[i]);
}
}
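
// Editor's note: the zero-padding step matters because `to_u32_digits` drops
// leading zero limbs, while the target may have more limb wires than the
// value has digits. A small sketch of that normalisation (plain vectors,
// hypothetical helper name):

fn pad_digits(mut digits: Vec<u32>, num_limbs: usize) -> Vec<u32> {
    assert!(num_limbs >= digits.len(), "value too large for the target");
    digits.resize(num_limbs, 0); // zero-extend little-endian base-2^32 digits
    digits
}

fn main() {
    assert_eq!(pad_digits(vec![5], 4), vec![5, 0, 0, 0]);
}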
pub fn set_nonnative_target<FF: Field>(&mut self, target: NonNativeTarget<FF>, value: FF) {
self.set_biguint_target(target.value, value.to_biguint())
}
pub fn set_hash_target(&mut self, ht: HashOutTarget, value: HashOut<F>) { pub fn set_hash_target(&mut self, ht: HashOutTarget, value: HashOut<F>) {
ht.elements ht.elements
.iter() .iter()

View File

@ -41,6 +41,7 @@ impl Target {
/// A `Target` which has already been constrained such that it can only be 0 or 1. /// A `Target` which has already been constrained such that it can only be 0 or 1.
#[derive(Copy, Clone, Debug)] #[derive(Copy, Clone, Debug)]
#[allow(clippy::manual_non_exhaustive)]
pub struct BoolTarget { pub struct BoolTarget {
pub target: Target, pub target: Target,
/// This private field is here to force all instantiations to go through `new_unsafe`. /// This private field is here to force all instantiations to go through `new_unsafe`.

View File

@ -1,9 +1,14 @@
use std::collections::HashMap; use std::collections::HashMap;
use std::convert::TryInto;
use num::{BigUint, FromPrimitive, Zero};
use crate::field::extension_field::target::ExtensionTarget; use crate::field::extension_field::target::ExtensionTarget;
use crate::field::extension_field::{Extendable, FieldExtension}; use crate::field::extension_field::{Extendable, FieldExtension};
use crate::field::field_types::{Field, RichField}; use crate::field::field_types::{Field, RichField};
use crate::field::field_types::Field;
use crate::gadgets::arithmetic_u32::U32Target;
use crate::gadgets::biguint::BigUintTarget;
use crate::gadgets::nonnative::NonNativeTarget;
use crate::hash::hash_types::HashOutTarget; use crate::hash::hash_types::HashOutTarget;
use crate::hash::hash_types::{HashOut, MerkleCapTarget}; use crate::hash::hash_types::{HashOut, MerkleCapTarget};
use crate::hash::merkle_tree::MerkleCap; use crate::hash::merkle_tree::MerkleCap;
@ -54,6 +59,24 @@ pub trait Witness<F: Field> {
panic!("not a bool") panic!("not a bool")
} }
fn get_biguint_target(&self, target: BigUintTarget) -> BigUint {
let mut result = BigUint::zero();
let limb_base = BigUint::from_u64(1 << 32u64).unwrap();
for i in (0..target.num_limbs()).rev() {
let limb = target.get_limb(i);
result *= &limb_base;
result += self.get_target(limb.0).to_biguint();
}
result
}
fn get_nonnative_target<FF: Field>(&self, target: NonNativeTarget<FF>) -> FF {
let val = self.get_biguint_target(target.value);
FF::from_biguint(val)
}
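
// Editor's note: `get_biguint_target` rebuilds the integer from its
// little-endian base-2^32 limbs by scanning them most-significant-first. The
// same fold on plain integers (u128 as a stand-in for BigUint):

fn recombine_limbs(limbs_le: &[u32]) -> u128 {
    limbs_le
        .iter()
        .rev()
        .fold(0u128, |acc, &limb| (acc << 32) + limb as u128)
}

fn main() {
    // Limbs [1, 2] encode 2 * 2^32 + 1.
    assert_eq!(recombine_limbs(&[1, 2]), (2u128 << 32) + 1);
}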
fn get_hash_target(&self, ht: HashOutTarget) -> HashOut<F> { fn get_hash_target(&self, ht: HashOutTarget) -> HashOut<F> {
HashOut { HashOut {
elements: self.get_targets(&ht.elements).try_into().unwrap(), elements: self.get_targets(&ht.elements).try_into().unwrap(),
@ -122,6 +145,16 @@ pub trait Witness<F: Field> {
self.set_target(target.target, F::from_bool(value)) self.set_target(target.target, F::from_bool(value))
} }
fn set_u32_target(&mut self, target: U32Target, value: u32) {
self.set_target(target.0, F::from_canonical_u32(value))
}
fn set_biguint_target(&mut self, target: &BigUintTarget, value: &BigUint) {
for (&lt, &l) in target.limbs.iter().zip(&value.to_u32_digits()) {
self.set_u32_target(lt, l);
}
}
fn set_wire(&mut self, wire: Wire, value: F) { fn set_wire(&mut self, wire: Wire, value: F) {
self.set_target(Target::Wire(wire), value) self.set_target(Target::Wire(wire), value)
} }

View File

@ -1,9 +1,15 @@
#![feature(asm)] #![allow(incomplete_features)]
#![feature(destructuring_assignment)] #![allow(const_evaluatable_unchecked)]
#![allow(clippy::new_without_default)]
#![allow(clippy::too_many_arguments)]
#![allow(clippy::len_without_is_empty)]
#![allow(clippy::needless_range_loop)]
#![feature(asm_sym)]
#![feature(generic_const_exprs)] #![feature(generic_const_exprs)]
#![feature(specialization)] #![feature(specialization)]
#![feature(stdsimd)] #![feature(stdsimd)]
pub mod curve;
pub mod field; pub mod field;
pub mod fri; pub mod fri;
pub mod gadgets; pub mod gadgets;
@ -13,3 +19,11 @@ pub mod iop;
pub mod plonk; pub mod plonk;
pub mod polynomial; pub mod polynomial;
pub mod util; pub mod util;
// Set up Jemalloc
#[cfg(not(target_env = "msvc"))]
use jemallocator::Jemalloc;
#[cfg(not(target_env = "msvc"))]
#[global_allocator]
static GLOBAL: Jemalloc = Jemalloc;

View File

@ -1,6 +1,5 @@
use std::cmp::max; use std::cmp::max;
use std::collections::{BTreeMap, HashMap, HashSet}; use std::collections::{BTreeMap, HashMap, HashSet};
use std::convert::TryInto;
use std::time::Instant; use std::time::Instant;
use log::{debug, info, Level}; use log::{debug, info, Level};
@ -12,14 +11,20 @@ use crate::field::fft::fft_root_table;
use crate::field::field_types::Field; use crate::field::field_types::Field;
use crate::fri::commitment::PolynomialBatchCommitment; use crate::fri::commitment::PolynomialBatchCommitment;
use crate::fri::{FriConfig, FriParams}; use crate::fri::{FriConfig, FriParams};
use crate::gadgets::arithmetic_extension::ArithmeticOperation; use crate::gadgets::arithmetic::BaseArithmeticOperation;
use crate::gates::arithmetic::ArithmeticExtensionGate; use crate::gadgets::arithmetic_extension::ExtensionArithmeticOperation;
use crate::gadgets::arithmetic_u32::U32Target;
use crate::gates::arithmetic_base::ArithmeticGate;
use crate::gates::arithmetic_extension::ArithmeticExtensionGate;
use crate::gates::arithmetic_u32::U32ArithmeticGate;
use crate::gates::constant::ConstantGate; use crate::gates::constant::ConstantGate;
use crate::gates::gate::{Gate, GateInstance, GateRef, PrefixedGate}; use crate::gates::gate::{Gate, GateInstance, GateRef, PrefixedGate};
use crate::gates::gate_tree::Tree; use crate::gates::gate_tree::Tree;
use crate::gates::multiplication_extension::MulExtensionGate;
use crate::gates::noop::NoopGate; use crate::gates::noop::NoopGate;
use crate::gates::public_input::PublicInputGate; use crate::gates::public_input::PublicInputGate;
use crate::gates::random_access::RandomAccessGate; use crate::gates::random_access::RandomAccessGate;
use crate::gates::subtraction_u32::U32SubtractionGate;
use crate::gates::switch::SwitchGate; use crate::gates::switch::SwitchGate;
use crate::hash::hash_types::{HashOutTarget, MerkleCapTarget}; use crate::hash::hash_types::{HashOutTarget, MerkleCapTarget};
use crate::iop::generator::{ use crate::iop::generator::{
@ -35,7 +40,7 @@ use crate::plonk::config::{GenericConfig, Hasher};
use crate::plonk::copy_constraint::CopyConstraint; use crate::plonk::copy_constraint::CopyConstraint;
use crate::plonk::permutation_argument::Forest; use crate::plonk::permutation_argument::Forest;
use crate::plonk::plonk_common::PlonkPolynomials; use crate::plonk::plonk_common::PlonkPolynomials;
use crate::polynomial::polynomial::PolynomialValues; use crate::polynomial::PolynomialValues;
use crate::util::context_tree::ContextTree; use crate::util::context_tree::ContextTree;
use crate::util::marking::{Markable, MarkedTargets}; use crate::util::marking::{Markable, MarkedTargets};
use crate::util::partial_products::num_partial_products; use crate::util::partial_products::num_partial_products;
@ -71,24 +76,13 @@ pub struct CircuitBuilder<F: Extendable<D>, const D: usize> {
constants_to_targets: HashMap<F, Target>, constants_to_targets: HashMap<F, Target>,
targets_to_constants: HashMap<Target, F>, targets_to_constants: HashMap<Target, F>,
/// Memoized results of `arithmetic` calls.
pub(crate) base_arithmetic_results: HashMap<BaseArithmeticOperation<F>, Target>,
/// Memoized results of `arithmetic_extension` calls. /// Memoized results of `arithmetic_extension` calls.
pub(crate) arithmetic_results: HashMap<ArithmeticOperation<F, D>, ExtensionTarget<D>>, pub(crate) arithmetic_results: HashMap<ExtensionArithmeticOperation<F, D>, ExtensionTarget<D>>,
/// A map `(c0, c1) -> (g, i)` from constants `(c0,c1)` to an available arithmetic gate using batched_gates: BatchedGates<F, D>,
/// these constants with gate index `g` and already using `i` arithmetic operations.
pub(crate) free_arithmetic: HashMap<(F, F), (usize, usize)>,
/// A map `(c0, c1) -> (g, i)` from constants `vec_size` to an available arithmetic gate using
/// these constants with gate index `g` and already using `i` random accesses.
pub(crate) free_random_access: HashMap<usize, (usize, usize)>,
// `current_switch_gates[chunk_size - 1]` contains None if we have no switch gates with the value
// chunk_size, and contains `(g, i, c)`, if the gate `g`, at index `i`, already contains `c` copies
// of switches
pub(crate) current_switch_gates: Vec<Option<(SwitchGate<F, D>, usize, usize)>>,
/// An available `ConstantGate` instance, if any.
free_constant: Option<(usize, usize)>,
} }
impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> { impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
@ -104,12 +98,10 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
marked_targets: Vec::new(), marked_targets: Vec::new(),
generators: Vec::new(), generators: Vec::new(),
constants_to_targets: HashMap::new(), constants_to_targets: HashMap::new(),
base_arithmetic_results: HashMap::new(),
arithmetic_results: HashMap::new(), arithmetic_results: HashMap::new(),
targets_to_constants: HashMap::new(), targets_to_constants: HashMap::new(),
free_arithmetic: HashMap::new(), batched_gates: BatchedGates::new(),
free_random_access: HashMap::new(),
current_switch_gates: Vec::new(),
free_constant: None,
}; };
builder.check_config(); builder.check_config();
builder builder
@ -216,6 +208,7 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
gate_ref, gate_ref,
constants, constants,
}); });
index index
} }
@ -260,6 +253,11 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
self.connect(x, zero); self.connect(x, zero);
} }
pub fn assert_one(&mut self, x: Target) {
let one = self.one();
self.connect(x, one);
}
pub fn add_generators(&mut self, generators: Vec<Box<dyn WitnessGenerator<F>>>) { pub fn add_generators(&mut self, generators: Vec<Box<dyn WitnessGenerator<F>>>) {
self.generators.extend(generators); self.generators.extend(generators);
} }
@ -313,26 +311,6 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
target target
} }
/// Returns the gate index and copy index of a free `ConstantGate` slot, potentially adding a
/// new `ConstantGate` if needed.
fn constant_gate_instance(&mut self) -> (usize, usize) {
if self.free_constant.is_none() {
let num_consts = self.config.constant_gate_size;
// We will fill this `ConstantGate` with zero constants initially.
// These will be overwritten by `constant` as the gate instances are filled.
let gate = self.add_gate(ConstantGate { num_consts }, vec![F::ZERO; num_consts]);
self.free_constant = Some((gate, 0));
}
let (gate, instance) = self.free_constant.unwrap();
if instance + 1 < self.config.constant_gate_size {
self.free_constant = Some((gate, instance + 1));
} else {
self.free_constant = None;
}
(gate, instance)
}
pub fn constants(&mut self, constants: &[F]) -> Vec<Target> { pub fn constants(&mut self, constants: &[F]) -> Vec<Target> {
constants.iter().map(|&c| self.constant(c)).collect() constants.iter().map(|&c| self.constant(c)).collect()
} }
@ -345,6 +323,11 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
} }
} }
/// Returns a U32Target for the value `c`, which is assumed to be at most 32 bits.
pub fn constant_u32(&mut self, c: u32) -> U32Target {
U32Target(self.constant(F::from_canonical_u32(c)))
}
/// If the given target is a constant (i.e. it was created by the `constant(F)` method), returns /// If the given target is a constant (i.e. it was created by the `constant(F)` method), returns
/// its constant value. Otherwise, returns `None`. /// its constant value. Otherwise, returns `None`.
pub fn target_as_constant(&self, target: Target) -> Option<F> { pub fn target_as_constant(&self, target: Target) -> Option<F> {
@ -396,6 +379,20 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
} }
} }
/// The number of (base field) `arithmetic` operations that can be performed in a single gate.
pub(crate) fn num_base_arithmetic_ops_per_gate(&self) -> usize {
if self.config.use_base_arithmetic_gate {
ArithmeticGate::new_from_config(&self.config).num_ops
} else {
self.num_ext_arithmetic_ops_per_gate()
}
}
/// The number of `arithmetic_extension` operations that can be performed in a single gate.
pub(crate) fn num_ext_arithmetic_ops_per_gate(&self) -> usize {
ArithmeticExtensionGate::<D>::new_from_config(&self.config).num_ops
}
/// The number of polynomial values that will be revealed per opening, both for the "regular" /// The number of polynomial values that will be revealed per opening, both for the "regular"
/// polynomials and for the Z polynomials. Because calculating these values involves a recursive /// polynomials and for the Z polynomials. Because calculating these values involves a recursive
/// dependence (the amount of blinding depends on the degree, which depends on the blinding), /// dependence (the amount of blinding depends on the degree, which depends on the blinding),
@ -566,76 +563,6 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
) )
} }
/// Fill the remaining unused arithmetic operations with zeros, so that all
/// `ArithmeticExtensionGenerator` are run.
fn fill_arithmetic_gates(&mut self) {
let zero = self.zero_extension();
let remaining_arithmetic_gates = self.free_arithmetic.values().copied().collect::<Vec<_>>();
for (gate, i) in remaining_arithmetic_gates {
for j in i..ArithmeticExtensionGate::<D>::num_ops(&self.config) {
let wires_multiplicand_0 = ExtensionTarget::from_range(
gate,
ArithmeticExtensionGate::<D>::wires_ith_multiplicand_0(j),
);
let wires_multiplicand_1 = ExtensionTarget::from_range(
gate,
ArithmeticExtensionGate::<D>::wires_ith_multiplicand_1(j),
);
let wires_addend = ExtensionTarget::from_range(
gate,
ArithmeticExtensionGate::<D>::wires_ith_addend(j),
);
self.connect_extension(zero, wires_multiplicand_0);
self.connect_extension(zero, wires_multiplicand_1);
self.connect_extension(zero, wires_addend);
}
}
}
/// Fill the remaining unused random access operations with zeros, so that all
/// `RandomAccessGenerator`s are run.
fn fill_random_access_gates(&mut self) {
let zero = self.zero();
for (vec_size, (_, i)) in self.free_random_access.clone() {
let max_copies = RandomAccessGate::<F, D>::max_num_copies(
self.config.num_routed_wires,
self.config.num_wires,
vec_size,
);
for _ in i..max_copies {
self.random_access(zero, zero, vec![zero; vec_size]);
}
}
}
/// Fill the remaining unused switch gates with dummy values, so that all
/// `SwitchGenerator` are run.
fn fill_switch_gates(&mut self) {
let zero = self.zero();
for chunk_size in 1..=self.current_switch_gates.len() {
if let Some((gate, gate_index, mut copy)) =
self.current_switch_gates[chunk_size - 1].clone()
{
while copy < gate.num_copies {
for element in 0..chunk_size {
let wire_first_input =
Target::wire(gate_index, gate.wire_first_input(copy, element));
let wire_second_input =
Target::wire(gate_index, gate.wire_second_input(copy, element));
let wire_switch_bool =
Target::wire(gate_index, gate.wire_switch_bool(copy));
self.connect(zero, wire_first_input);
self.connect(zero, wire_second_input);
self.connect(zero, wire_switch_bool);
}
copy += 1;
}
}
}
}
pub fn print_gate_counts(&self, min_delta: usize) { pub fn print_gate_counts(&self, min_delta: usize) {
// Print gate counts for each context. // Print gate counts for each context.
self.context_log self.context_log
@ -659,9 +586,7 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
let mut timing = TimingTree::new("preprocess", Level::Trace); let mut timing = TimingTree::new("preprocess", Level::Trace);
let start = Instant::now(); let start = Instant::now();
self.fill_arithmetic_gates(); self.fill_batched_gates();
self.fill_random_access_gates();
self.fill_switch_gates();
// Hash the public inputs, and route them to a `PublicInputGate` which will enforce that // Hash the public inputs, and route them to a `PublicInputGate` which will enforce that
// those hash wires match the claimed public inputs. // those hash wires match the claimed public inputs.
@ -698,7 +623,7 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
..=1 << self.config.rate_bits) ..=1 << self.config.rate_bits)
.min_by_key(|&q| num_partial_products(self.config.num_routed_wires, q).0 + q) .min_by_key(|&q| num_partial_products(self.config.num_routed_wires, q).0 + q)
.unwrap(); .unwrap();
info!("Quotient degree factor set to: {}.", quotient_degree_factor); debug!("Quotient degree factor set to: {}.", quotient_degree_factor);
let prefixed_gates = PrefixedGate::from_tree(gate_tree); let prefixed_gates = PrefixedGate::from_tree(gate_tree);
let subgroup = F::two_adic_subgroup(degree_bits); let subgroup = F::two_adic_subgroup(degree_bits);
@ -710,7 +635,7 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
// Precompute FFT roots. // Precompute FFT roots.
let max_fft_points = let max_fft_points =
1 << degree_bits + max(self.config.rate_bits, log2_ceil(quotient_degree_factor)); 1 << (degree_bits + max(self.config.rate_bits, log2_ceil(quotient_degree_factor)));
let fft_root_table = fft_root_table(max_fft_points); let fft_root_table = fft_root_table(max_fft_points);
let constants_sigmas_vecs = [constant_vecs, sigma_vecs.clone()].concat(); let constants_sigmas_vecs = [constant_vecs, sigma_vecs.clone()].concat();
@ -745,7 +670,7 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
let watch_rep_index = forest.parents[watch_index]; let watch_rep_index = forest.parents[watch_index];
generator_indices_by_watches generator_indices_by_watches
.entry(watch_rep_index) .entry(watch_rep_index)
.or_insert(vec![]) .or_insert_with(Vec::new)
.push(i); .push(i);
} }
} }
@ -801,7 +726,7 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
circuit_digest, circuit_digest,
}; };
info!("Building circuit took {}s", start.elapsed().as_secs_f32()); debug!("Building circuit took {}s", start.elapsed().as_secs_f32());
CircuitData { CircuitData {
prover_only, prover_only,
verifier_only, verifier_only,
@ -837,3 +762,386 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
} }
} }
} }
/// Various gate types can hold multiple copies of an operation within a single gate instance. This
/// helper struct lets a `CircuitBuilder` track such gates that are currently being "filled up."
pub struct BatchedGates<F: RichField + Extendable<D>, const D: usize> {
/// A map `(c0, c1) -> (g, i)` from constants `(c0,c1)` to an available arithmetic gate using
/// these constants with gate index `g` and already using `i` arithmetic operations.
pub(crate) free_arithmetic: HashMap<(F, F), (usize, usize)>,
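    /// Like `free_arithmetic`, but tracking the base-field `ArithmeticGate` rather than the
    /// extension gate (a descriptive comment added here; see `find_base_arithmetic_gate` below).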
pub(crate) free_base_arithmetic: HashMap<(F, F), (usize, usize)>,
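    /// A map `c0 -> (g, i)` from constant `c0` to an available multiplication gate using that
    /// constant, with gate index `g` and already using `i` operations (see `find_mul_gate` below).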
pub(crate) free_mul: HashMap<F, (usize, usize)>,
/// A map `b -> (g, i)` from `b` bits to an available random access gate of that size with gate
/// index `g` and already using `i` random accesses.
pub(crate) free_random_access: HashMap<usize, (usize, usize)>,
    /// `current_switch_gates[chunk_size - 1]` contains `None` if we have no switch gate with chunk size
    /// `chunk_size`, and contains `(g, i, c)` if the gate `g` at index `i` already contains `c` copies
    /// of switches.
pub(crate) current_switch_gates: Vec<Option<(SwitchGate<F, D>, usize, usize)>>,
/// The `U32ArithmeticGate` currently being filled (so new u32 arithmetic operations will be added to this gate before creating a new one)
pub(crate) current_u32_arithmetic_gate: Option<(usize, usize)>,
/// The `U32SubtractionGate` currently being filled (so new u32 subtraction operations will be added to this gate before creating a new one)
pub(crate) current_u32_subtraction_gate: Option<(usize, usize)>,
/// An available `ConstantGate` instance, if any.
pub(crate) free_constant: Option<(usize, usize)>,
}
impl<F: RichField + Extendable<D>, const D: usize> BatchedGates<F, D> {
pub fn new() -> Self {
Self {
free_arithmetic: HashMap::new(),
free_base_arithmetic: HashMap::new(),
free_mul: HashMap::new(),
free_random_access: HashMap::new(),
current_switch_gates: Vec::new(),
current_u32_arithmetic_gate: None,
current_u32_subtraction_gate: None,
free_constant: None,
}
}
}
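The `find_*` helpers below all share one slot-allocation pattern: look up a partially filled gate, hand out its next free slot, and drop the entry once the gate is full. A minimal standalone sketch of that pattern (the `find_slot` helper and the `ops_per_gate` capacity are illustrative assumptions, not part of this codebase):

use std::collections::HashMap;

/// Hands out the next free slot in a partially filled gate keyed by `key`,
/// allocating a fresh gate via `add_gate` when none is in progress.
/// Returns `(gate_index, slot_index)`.
fn find_slot<K: std::hash::Hash + Eq + Copy>(
    free: &mut HashMap<K, (usize, usize)>,
    key: K,
    ops_per_gate: usize,
    mut add_gate: impl FnMut() -> usize,
) -> (usize, usize) {
    let (gate, i) = free.get(&key).copied().unwrap_or_else(|| (add_gate(), 0));
    if i + 1 < ops_per_gate {
        // The gate still has room; remember its next free slot.
        free.insert(key, (gate, i + 1));
    } else {
        // The gate is full; the next call with this key allocates a new one.
        free.remove(&key);
    }
    (gate, i)
}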
impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
    /// Finds the last available arithmetic gate with the given constants or adds one if there isn't any.
/// Returns `(g,i)` such that there is an arithmetic gate with the given constants at index
/// `g` and the gate's `i`-th operation is available.
pub(crate) fn find_base_arithmetic_gate(&mut self, const_0: F, const_1: F) -> (usize, usize) {
let (gate, i) = self
.batched_gates
.free_base_arithmetic
.get(&(const_0, const_1))
.copied()
.unwrap_or_else(|| {
let gate = self.add_gate(
ArithmeticGate::new_from_config(&self.config),
vec![const_0, const_1],
);
(gate, 0)
});
        // Update `free_base_arithmetic` with new values.
if i < ArithmeticGate::num_ops(&self.config) - 1 {
self.batched_gates
.free_base_arithmetic
.insert((const_0, const_1), (gate, i + 1));
} else {
self.batched_gates
.free_base_arithmetic
.remove(&(const_0, const_1));
}
(gate, i)
}
    /// Finds the last available arithmetic gate with the given constants or adds one if there isn't any.
/// Returns `(g,i)` such that there is an arithmetic gate with the given constants at index
/// `g` and the gate's `i`-th operation is available.
pub(crate) fn find_arithmetic_gate(&mut self, const_0: F, const_1: F) -> (usize, usize) {
let (gate, i) = self
.batched_gates
.free_arithmetic
.get(&(const_0, const_1))
.copied()
.unwrap_or_else(|| {
let gate = self.add_gate(
ArithmeticExtensionGate::new_from_config(&self.config),
vec![const_0, const_1],
);
(gate, 0)
});
// Update `free_arithmetic` with new values.
if i < ArithmeticExtensionGate::<D>::num_ops(&self.config) - 1 {
self.batched_gates
.free_arithmetic
.insert((const_0, const_1), (gate, i + 1));
} else {
self.batched_gates
.free_arithmetic
.remove(&(const_0, const_1));
}
(gate, i)
}
    /// Finds the last available multiplication gate with the given constant or adds one if there isn't any.
    /// Returns `(g,i)` such that there is a multiplication gate with the given constant at index
    /// `g` and the gate's `i`-th operation is available.
pub(crate) fn find_mul_gate(&mut self, const_0: F) -> (usize, usize) {
let (gate, i) = self
.batched_gates
.free_mul
.get(&const_0)
.copied()
.unwrap_or_else(|| {
let gate = self.add_gate(
MulExtensionGate::new_from_config(&self.config),
vec![const_0],
);
(gate, 0)
});
        // Update `free_mul` with new values.
if i < MulExtensionGate::<D>::num_ops(&self.config) - 1 {
self.batched_gates.free_mul.insert(const_0, (gate, i + 1));
} else {
self.batched_gates.free_mul.remove(&const_0);
}
(gate, i)
}
    /// Finds the last available random access gate for `bits`-bit indices or adds one if there isn't any.
    /// Returns `(g,i)` such that there is a random access gate for `bits`-bit indices at index
    /// `g` and the gate's `i`-th random access is available.
pub(crate) fn find_random_access_gate(&mut self, bits: usize) -> (usize, usize) {
let (gate, i) = self
.batched_gates
.free_random_access
.get(&bits)
.copied()
.unwrap_or_else(|| {
let gate = self.add_gate(
RandomAccessGate::new_from_config(&self.config, bits),
vec![],
);
(gate, 0)
});
// Update `free_random_access` with new values.
if i + 1 < RandomAccessGate::<F, D>::new_from_config(&self.config, bits).num_copies {
self.batched_gates
.free_random_access
.insert(bits, (gate, i + 1));
} else {
self.batched_gates.free_random_access.remove(&bits);
}
(gate, i)
}
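A hedged usage sketch of the batching above, assuming the three-argument `random_access(index, claimed, vec)` shape used by the fill code below; a gate found with `bits = 2` serves accesses into length-4 vectors:

// Assert in-circuit that v[index] == claimed; with a length-4 vector this is
// batched into a RandomAccessGate for 2-bit indices, sharing slots with other
// such accesses until the gate is full.
fn sample_access<F: RichField + Extendable<D>, const D: usize>(b: &mut CircuitBuilder<F, D>) {
    let v = vec![b.zero(), b.one(), b.zero(), b.one()];
    let index = b.constant(F::from_canonical_usize(3));
    let claimed = b.one();
    b.random_access(index, claimed, v);
}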
pub(crate) fn find_switch_gate(
&mut self,
chunk_size: usize,
) -> (SwitchGate<F, D>, usize, usize) {
if self.batched_gates.current_switch_gates.len() < chunk_size {
self.batched_gates.current_switch_gates.extend(vec![
None;
chunk_size
- self
.batched_gates
.current_switch_gates
.len()
]);
}
let (gate, gate_index, next_copy) =
match self.batched_gates.current_switch_gates[chunk_size - 1].clone() {
None => {
let gate = SwitchGate::<F, D>::new_from_config(&self.config, chunk_size);
let gate_index = self.add_gate(gate.clone(), vec![]);
(gate, gate_index, 0)
}
Some((gate, idx, next_copy)) => (gate, idx, next_copy),
};
let num_copies = gate.num_copies;
if next_copy == num_copies - 1 {
self.batched_gates.current_switch_gates[chunk_size - 1] = None;
} else {
self.batched_gates.current_switch_gates[chunk_size - 1] =
Some((gate.clone(), gate_index, next_copy + 1));
}
(gate, gate_index, next_copy)
}
pub(crate) fn find_u32_arithmetic_gate(&mut self) -> (usize, usize) {
let (gate_index, copy) = match self.batched_gates.current_u32_arithmetic_gate {
None => {
let gate = U32ArithmeticGate::new_from_config(&self.config);
let gate_index = self.add_gate(gate, vec![]);
(gate_index, 0)
}
Some((gate_index, copy)) => (gate_index, copy),
};
if copy == U32ArithmeticGate::<F, D>::num_ops(&self.config) - 1 {
self.batched_gates.current_u32_arithmetic_gate = None;
} else {
self.batched_gates.current_u32_arithmetic_gate = Some((gate_index, copy + 1));
}
(gate_index, copy)
}
pub(crate) fn find_u32_subtraction_gate(&mut self) -> (usize, usize) {
let (gate_index, copy) = match self.batched_gates.current_u32_subtraction_gate {
None => {
let gate = U32SubtractionGate::new_from_config(&self.config);
let gate_index = self.add_gate(gate, vec![]);
(gate_index, 0)
}
Some((gate_index, copy)) => (gate_index, copy),
};
if copy == U32SubtractionGate::<F, D>::num_ops(&self.config) - 1 {
self.batched_gates.current_u32_subtraction_gate = None;
} else {
self.batched_gates.current_u32_subtraction_gate = Some((gate_index, copy + 1));
}
(gate_index, copy)
}
/// Returns the gate index and copy index of a free `ConstantGate` slot, potentially adding a
/// new `ConstantGate` if needed.
fn constant_gate_instance(&mut self) -> (usize, usize) {
if self.batched_gates.free_constant.is_none() {
let num_consts = self.config.constant_gate_size;
// We will fill this `ConstantGate` with zero constants initially.
// These will be overwritten by `constant` as the gate instances are filled.
let gate = self.add_gate(ConstantGate { num_consts }, vec![F::ZERO; num_consts]);
self.batched_gates.free_constant = Some((gate, 0));
}
let (gate, instance) = self.batched_gates.free_constant.unwrap();
if instance + 1 < self.config.constant_gate_size {
self.batched_gates.free_constant = Some((gate, instance + 1));
} else {
self.batched_gates.free_constant = None;
}
(gate, instance)
}
    /// Fill the remaining unused arithmetic operations with zeros, so that all
    /// `ArithmeticGate` generators are run.
fn fill_base_arithmetic_gates(&mut self) {
let zero = self.zero();
for ((c0, c1), (_gate, i)) in self.batched_gates.free_base_arithmetic.clone() {
for _ in i..ArithmeticGate::num_ops(&self.config) {
// If we directly wire in zero, an optimization will skip doing anything and return
// zero. So we pass in a virtual target and connect it to zero afterward.
let dummy = self.add_virtual_target();
self.arithmetic(c0, c1, dummy, dummy, dummy);
self.connect(dummy, zero);
}
}
assert!(self.batched_gates.free_base_arithmetic.is_empty());
}
/// Fill the remaining unused arithmetic operations with zeros, so that all
/// `ArithmeticExtensionGenerator`s are run.
fn fill_arithmetic_gates(&mut self) {
let zero = self.zero_extension();
for ((c0, c1), (_gate, i)) in self.batched_gates.free_arithmetic.clone() {
for _ in i..ArithmeticExtensionGate::<D>::num_ops(&self.config) {
// If we directly wire in zero, an optimization will skip doing anything and return
// zero. So we pass in a virtual target and connect it to zero afterward.
let dummy = self.add_virtual_extension_target();
self.arithmetic_extension(c0, c1, dummy, dummy, dummy);
self.connect_extension(dummy, zero);
}
}
assert!(self.batched_gates.free_arithmetic.is_empty());
}
    /// Fill the remaining unused multiplication operations with zeros, so that all
    /// `MulExtensionGate` generators are run.
fn fill_mul_gates(&mut self) {
let zero = self.zero_extension();
for (c0, (_gate, i)) in self.batched_gates.free_mul.clone() {
for _ in i..MulExtensionGate::<D>::num_ops(&self.config) {
// If we directly wire in zero, an optimization will skip doing anything and return
// zero. So we pass in a virtual target and connect it to zero afterward.
let dummy = self.add_virtual_extension_target();
self.arithmetic_extension(c0, F::ZERO, dummy, dummy, zero);
self.connect_extension(dummy, zero);
}
}
assert!(self.batched_gates.free_mul.is_empty());
}
/// Fill the remaining unused random access operations with zeros, so that all
/// `RandomAccessGenerator`s are run.
fn fill_random_access_gates(&mut self) {
let zero = self.zero();
for (bits, (_, i)) in self.batched_gates.free_random_access.clone() {
let max_copies =
RandomAccessGate::<F, D>::new_from_config(&self.config, bits).num_copies;
for _ in i..max_copies {
self.random_access(zero, zero, vec![zero; 1 << bits]);
}
}
}
/// Fill the remaining unused switch gates with dummy values, so that all
/// `SwitchGenerator`s are run.
fn fill_switch_gates(&mut self) {
let zero = self.zero();
for chunk_size in 1..=self.batched_gates.current_switch_gates.len() {
if let Some((gate, gate_index, mut copy)) =
self.batched_gates.current_switch_gates[chunk_size - 1].clone()
{
while copy < gate.num_copies {
for element in 0..chunk_size {
let wire_first_input =
Target::wire(gate_index, gate.wire_first_input(copy, element));
let wire_second_input =
Target::wire(gate_index, gate.wire_second_input(copy, element));
let wire_switch_bool =
Target::wire(gate_index, gate.wire_switch_bool(copy));
self.connect(zero, wire_first_input);
self.connect(zero, wire_second_input);
self.connect(zero, wire_switch_bool);
}
copy += 1;
}
}
}
}
/// Fill the remaining unused U32 arithmetic operations with zeros, so that all
/// `U32ArithmeticGenerator`s are run.
fn fill_u32_arithmetic_gates(&mut self) {
let zero = self.zero_u32();
if let Some((_gate_index, copy)) = self.batched_gates.current_u32_arithmetic_gate {
for _ in copy..U32ArithmeticGate::<F, D>::num_ops(&self.config) {
let dummy = self.add_virtual_u32_target();
self.mul_add_u32(dummy, dummy, dummy);
self.connect_u32(dummy, zero);
}
}
}
/// Fill the remaining unused U32 subtraction operations with zeros, so that all
/// `U32SubtractionGenerator`s are run.
fn fill_u32_subtraction_gates(&mut self) {
let zero = self.zero_u32();
if let Some((_gate_index, copy)) = self.batched_gates.current_u32_subtraction_gate {
            for _ in copy..U32SubtractionGate::<F, D>::num_ops(&self.config) {
let dummy = self.add_virtual_u32_target();
self.sub_u32(dummy, dummy, dummy);
self.connect_u32(dummy, zero);
}
}
}
fn fill_batched_gates(&mut self) {
self.fill_arithmetic_gates();
self.fill_base_arithmetic_gates();
self.fill_mul_gates();
self.fill_random_access_gates();
self.fill_switch_gates();
self.fill_u32_arithmetic_gates();
self.fill_u32_subtraction_gates();
}
}

View File

@ -16,6 +16,7 @@ use crate::iop::target::Target;
use crate::iop::witness::PartialWitness; use crate::iop::witness::PartialWitness;
use crate::plonk::config::{GenericConfig, Hasher}; use crate::plonk::config::{GenericConfig, Hasher};
use crate::plonk::proof::ProofWithPublicInputs; use crate::plonk::proof::{CompressedProofWithPublicInputs, ProofWithPublicInputs};
use crate::plonk::prover::prove; use crate::plonk::prover::prove;
use crate::plonk::verifier::verify; use crate::plonk::verifier::verify;
use crate::util::marking::MarkedTargets; use crate::util::marking::MarkedTargets;
@ -26,6 +27,9 @@ pub struct CircuitConfig {
pub num_wires: usize, pub num_wires: usize,
pub num_routed_wires: usize, pub num_routed_wires: usize,
pub constant_gate_size: usize, pub constant_gate_size: usize,
/// Whether to use a dedicated gate for base field arithmetic, rather than using a single gate
/// for both base field and extension field arithmetic.
pub use_base_arithmetic_gate: bool,
pub security_bits: usize, pub security_bits: usize,
pub rate_bits: usize, pub rate_bits: usize,
/// The number of challenge points to generate, for IOPs that have soundness errors of (roughly) /// The number of challenge points to generate, for IOPs that have soundness errors of (roughly)
@ -45,30 +49,35 @@ impl Default for CircuitConfig {
} }
impl CircuitConfig { impl CircuitConfig {
pub fn rate(&self) -> f64 {
1.0 / ((1 << self.rate_bits) as f64)
}
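    // For example, with `rate_bits = 3` (the standard recursion preset below),
    // `rate()` returns 1.0 / 8.0 = 0.125.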
pub fn num_advice_wires(&self) -> usize { pub fn num_advice_wires(&self) -> usize {
self.num_wires - self.num_routed_wires self.num_wires - self.num_routed_wires
} }
/// A typical recursion config, without zero-knowledge, targeting ~100 bit security. /// A typical recursion config, without zero-knowledge, targeting ~100 bit security.
pub(crate) fn standard_recursion_config() -> Self { pub fn standard_recursion_config() -> Self {
Self { Self {
num_wires: 143, num_wires: 135,
num_routed_wires: 25, num_routed_wires: 80,
constant_gate_size: 6, constant_gate_size: 5,
use_base_arithmetic_gate: true,
security_bits: 100, security_bits: 100,
rate_bits: 3, rate_bits: 3,
num_challenges: 2, num_challenges: 2,
zero_knowledge: false, zero_knowledge: false,
cap_height: 3, cap_height: 4,
fri_config: FriConfig { fri_config: FriConfig {
proof_of_work_bits: 16, proof_of_work_bits: 16,
reduction_strategy: FriReductionStrategy::ConstantArityBits(3, 5), reduction_strategy: FriReductionStrategy::ConstantArityBits(4, 5),
num_query_rounds: 28, num_query_rounds: 28,
}, },
} }
} }
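A hedged sketch of how a preset is consumed end to end, under the imports of this file (a minimal example, not a prescribed workflow):

// Build a trivial circuit under the standard recursion preset; `C` is any
// GenericConfig implementation.
fn build_trivial<F: Extendable<D>, C: GenericConfig<D, F = F>, const D: usize>(
) -> CircuitData<F, C, D> {
    let config = CircuitConfig::standard_recursion_config();
    let mut builder = CircuitBuilder::<F, D>::new(config);
    let x = builder.add_virtual_target();
    builder.assert_one(x); // Uses the `assert_one` helper added in this change.
    builder.build::<C>()
}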
pub(crate) fn standard_recursion_zk_config() -> Self { pub fn standard_recursion_zk_config() -> Self {
CircuitConfig { CircuitConfig {
zero_knowledge: true, zero_knowledge: true,
..Self::standard_recursion_config() ..Self::standard_recursion_config()
@ -96,6 +105,13 @@ impl<F: Extendable<D>, C: GenericConfig<D, F = F>, const D: usize> CircuitData<F
pub fn verify(&self, proof_with_pis: ProofWithPublicInputs<F, C, D>) -> Result<()> { pub fn verify(&self, proof_with_pis: ProofWithPublicInputs<F, C, D>) -> Result<()> {
verify(proof_with_pis, &self.verifier_only, &self.common) verify(proof_with_pis, &self.verifier_only, &self.common)
} }
pub fn verify_compressed(
&self,
compressed_proof_with_pis: CompressedProofWithPublicInputs<F, D>,
) -> Result<()> {
compressed_proof_with_pis.verify(&self.verifier_only, &self.common)
}
} }
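A hedged round-trip sketch of the new compressed path, mirroring the test at the bottom of `proof.rs` (`data` and `pw` are assumed to be in scope):

// Prove, compress, decompress (sanity check), then verify the compressed form.
let proof = data.prove(pw)?;
let compressed = proof.clone().compress(&data.common)?;
assert_eq!(proof, compressed.clone().decompress(&data.common)?);
data.verify_compressed(compressed)?;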
/// Circuit data required by the prover. This may be thought of as a proving key, although it /// Circuit data required by the prover. This may be thought of as a proving key, although it
@ -132,6 +148,13 @@ impl<F: Extendable<D>, C: GenericConfig<D, F = F>, const D: usize> VerifierCircu
pub fn verify(&self, proof_with_pis: ProofWithPublicInputs<F, C, D>) -> Result<()> { pub fn verify(&self, proof_with_pis: ProofWithPublicInputs<F, C, D>) -> Result<()> {
verify(proof_with_pis, &self.verifier_only, &self.common) verify(proof_with_pis, &self.verifier_only, &self.common)
} }
pub fn verify_compressed(
&self,
compressed_proof_with_pis: CompressedProofWithPublicInputs<F, D>,
) -> Result<()> {
compressed_proof_with_pis.verify(&self.verifier_only, &self.common)
}
} }
/// Circuit data required by the prover, but not the verifier. /// Circuit data required by the prover, but not the verifier.
@ -194,8 +217,8 @@ pub struct CommonCircuitData<F: Extendable<D>, C: GenericConfig<D, F = F>, const
/// The `{k_i}` valued used in `S_ID_i` in Plonk's permutation argument. /// The `{k_i}` valued used in `S_ID_i` in Plonk's permutation argument.
pub(crate) k_is: Vec<F>, pub(crate) k_is: Vec<F>,
/// The number of partial products needed to compute the `Z` polynomials and the number /// The number of partial products needed to compute the `Z` polynomials and
/// of partial products needed to compute the final product. /// the number of original elements consumed in `partial_products()`.
pub(crate) num_partial_products: (usize, usize), pub(crate) num_partial_products: (usize, usize),
/// A digest of the "circuit" (i.e. the instance, minus public inputs), which can be used to /// A digest of the "circuit" (i.e. the instance, minus public inputs), which can be used to
@ -228,11 +251,6 @@ impl<F: Extendable<D>, C: GenericConfig<D, F = F>, const D: usize> CommonCircuit
self.quotient_degree_factor * self.degree() self.quotient_degree_factor * self.degree()
} }
pub fn total_constraints(&self) -> usize {
// 2 constraints for each Z check.
self.config.num_challenges * 2 + self.num_gate_constraints
}
/// Range of the constants polynomials in the `constants_sigmas_commitment`. /// Range of the constants polynomials in the `constants_sigmas_commitment`.
pub fn constants_range(&self) -> Range<usize> { pub fn constants_range(&self) -> Range<usize> {
0..self.num_constants 0..self.num_constants

View File

@ -11,7 +11,7 @@ use crate::plonk::proof::{
CompressedProof, CompressedProofWithPublicInputs, FriInferredElements, OpeningSet, Proof, CompressedProof, CompressedProofWithPublicInputs, FriInferredElements, OpeningSet, Proof,
ProofChallenges, ProofWithPublicInputs, ProofChallenges, ProofWithPublicInputs,
}; };
use crate::polynomial::polynomial::PolynomialCoeffs; use crate::polynomial::PolynomialCoeffs;
use crate::util::reverse_bits; use crate::util::reverse_bits;
fn get_challenges<F: Extendable<D>, C: GenericConfig<D, F = F>, const D: usize>( fn get_challenges<F: Extendable<D>, C: GenericConfig<D, F = F>, const D: usize>(

View File

@ -5,7 +5,7 @@ use rayon::prelude::*;
use crate::field::field_types::Field; use crate::field::field_types::Field;
use crate::iop::target::Target; use crate::iop::target::Target;
use crate::iop::wire::Wire; use crate::iop::wire::Wire;
use crate::polynomial::polynomial::PolynomialValues; use crate::polynomial::PolynomialValues;
/// Disjoint Set Forest data-structure following https://en.wikipedia.org/wiki/Disjoint-set_data_structure. /// Disjoint Set Forest data-structure following https://en.wikipedia.org/wiki/Disjoint-set_data_structure.
pub struct Forest { pub struct Forest {
@ -45,15 +45,23 @@ impl Forest {
} }
/// Path compression method, see https://en.wikipedia.org/wiki/Disjoint-set_data_structure#Finding_set_representatives. /// Path compression method, see https://en.wikipedia.org/wiki/Disjoint-set_data_structure#Finding_set_representatives.
pub fn find(&mut self, x_index: usize) -> usize { pub fn find(&mut self, mut x_index: usize) -> usize {
let x_parent = self.parents[x_index]; // Note: We avoid recursion here since the chains can be long, causing stack overflows.
if x_parent != x_index {
let root_index = self.find(x_parent); // First, find the representative of the set containing `x_index`.
self.parents[x_index] = root_index; let mut representative = x_index;
root_index while self.parents[representative] != representative {
} else { representative = self.parents[representative];
x_index
} }
// Then, update each node in this chain to point directly to the representative.
while self.parents[x_index] != x_index {
let old_parent = self.parents[x_index];
self.parents[x_index] = representative;
x_index = old_parent;
}
representative
} }
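A standalone sketch of the same two-pass walk on a bare parent vector, to make the flattening concrete (illustrative only, not the `Forest` API):

/// The iterative find: walk to the root, then re-point the chain at it.
fn find(parents: &mut Vec<usize>, mut x: usize) -> usize {
    // Pass 1: walk up to the representative without recursing.
    let mut root = x;
    while parents[root] != root {
        root = parents[root];
    }
    // Pass 2: point every node on the walked chain directly at the representative.
    while parents[x] != x {
        let old_parent = parents[x];
        parents[x] = root;
        x = old_parent;
    }
    root
}

fn main() {
    // Chain 3 -> 2 -> 1 -> 0; a single find(3) flattens the whole path.
    let mut parents = vec![0, 0, 1, 2];
    assert_eq!(find(&mut parents, 3), 0);
    assert_eq!(parents, vec![0, 0, 0, 0]);
}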
/// Merge two sets. /// Merge two sets.

View File

@ -42,17 +42,6 @@ impl PlonkPolynomials {
index: 3, index: 3,
blinding: true, blinding: true,
}; };
#[cfg(test)]
pub fn polynomials(i: usize) -> PolynomialsIndexBlinding {
match i {
0 => Self::CONSTANTS_SIGMAS,
1 => Self::WIRES,
2 => Self::ZS_PARTIAL_PRODUCTS,
3 => Self::QUOTIENT,
_ => panic!("There are only 4 sets of polynomials in Plonk."),
}
}
} }
/// Evaluate the polynomial which vanishes on any multiplicative subgroup of a given order `n`. /// Evaluate the polynomial which vanishes on any multiplicative subgroup of a given order `n`.

View File

@ -164,12 +164,12 @@ impl<F: Extendable<D>, C: GenericConfig<D, F = F>, const D: usize>
) -> anyhow::Result<ProofWithPublicInputs<F, C, D>> { ) -> anyhow::Result<ProofWithPublicInputs<F, C, D>> {
let challenges = self.get_challenges(common_data)?; let challenges = self.get_challenges(common_data)?;
let fri_inferred_elements = self.get_inferred_elements(&challenges, common_data); let fri_inferred_elements = self.get_inferred_elements(&challenges, common_data);
let compressed_proof = let decompressed_proof =
self.proof self.proof
.decompress(&challenges, fri_inferred_elements, common_data); .decompress(&challenges, fri_inferred_elements, common_data);
Ok(ProofWithPublicInputs { Ok(ProofWithPublicInputs {
public_inputs: self.public_inputs, public_inputs: self.public_inputs,
proof: compressed_proof, proof: decompressed_proof,
}) })
} }
@ -180,13 +180,13 @@ impl<F: Extendable<D>, C: GenericConfig<D, F = F>, const D: usize>
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
let challenges = self.get_challenges(common_data)?; let challenges = self.get_challenges(common_data)?;
let fri_inferred_elements = self.get_inferred_elements(&challenges, common_data); let fri_inferred_elements = self.get_inferred_elements(&challenges, common_data);
let compressed_proof = let decompressed_proof =
self.proof self.proof
.decompress(&challenges, fri_inferred_elements, common_data); .decompress(&challenges, fri_inferred_elements, common_data);
verify_with_challenges( verify_with_challenges(
ProofWithPublicInputs { ProofWithPublicInputs {
public_inputs: self.public_inputs, public_inputs: self.public_inputs,
proof: compressed_proof, proof: decompressed_proof,
}, },
challenges, challenges,
verifier_data, verifier_data,
@ -312,6 +312,7 @@ mod tests {
use crate::field::field_types::Field; use crate::field::field_types::Field;
use crate::fri::reduction_strategies::FriReductionStrategy; use crate::fri::reduction_strategies::FriReductionStrategy;
use crate::gates::noop::NoopGate;
use crate::iop::witness::PartialWitness; use crate::iop::witness::PartialWitness;
use crate::plonk::circuit_builder::CircuitBuilder; use crate::plonk::circuit_builder::CircuitBuilder;
use crate::plonk::circuit_data::CircuitConfig; use crate::plonk::circuit_data::CircuitConfig;
@ -340,6 +341,9 @@ mod tests {
let zt = builder.constant(z); let zt = builder.constant(z);
let comp_zt = builder.mul(xt, yt); let comp_zt = builder.mul(xt, yt);
builder.connect(zt, comp_zt); builder.connect(zt, comp_zt);
for _ in 0..100 {
builder.add_gate(NoopGate, vec![]);
}
let data = builder.build::<C>(); let data = builder.build::<C>();
let proof = data.prove(pw)?; let proof = data.prove(pw)?;
verify(proof.clone(), &data.verifier_only, &data.common)?; verify(proof.clone(), &data.verifier_only, &data.common)?;
@ -350,6 +354,6 @@ mod tests {
assert_eq!(proof, decompressed_compressed_proof); assert_eq!(proof, decompressed_compressed_proof);
verify(proof, &data.verifier_only, &data.common)?; verify(proof, &data.verifier_only, &data.common)?;
compressed_proof.verify(&data.verifier_only, &data.common) data.verify_compressed(compressed_proof)
} }
} }

View File

@ -1,3 +1,5 @@
use std::mem::swap;
use anyhow::Result; use anyhow::Result;
use rayon::prelude::*; use rayon::prelude::*;
@ -13,9 +15,9 @@ use crate::plonk::plonk_common::ZeroPolyOnCoset;
use crate::plonk::proof::{Proof, ProofWithPublicInputs}; use crate::plonk::proof::{Proof, ProofWithPublicInputs};
use crate::plonk::vanishing_poly::eval_vanishing_poly_base_batch; use crate::plonk::vanishing_poly::eval_vanishing_poly_base_batch;
use crate::plonk::vars::EvaluationVarsBase; use crate::plonk::vars::EvaluationVarsBase;
use crate::polynomial::polynomial::{PolynomialCoeffs, PolynomialValues}; use crate::polynomial::{PolynomialCoeffs, PolynomialValues};
use crate::timed; use crate::timed;
use crate::util::partial_products::partial_products; use crate::util::partial_products::{partial_products_and_z_gx, quotient_chunk_products};
use crate::util::timing::TimingTree; use crate::util::timing::TimingTree;
use crate::util::{log2_ceil, transpose}; use crate::util::{log2_ceil, transpose};
@ -89,28 +91,22 @@ pub(crate) fn prove<F: Extendable<D>, C: GenericConfig<D, F = F>, const D: usize
common_data.quotient_degree_factor < common_data.config.num_routed_wires, common_data.quotient_degree_factor < common_data.config.num_routed_wires,
"When the number of routed wires is smaller that the degree, we should change the logic to avoid computing partial products." "When the number of routed wires is smaller that the degree, we should change the logic to avoid computing partial products."
); );
let mut partial_products = timed!( let mut partial_products_and_zs = timed!(
timing, timing,
"compute partial products", "compute partial products",
all_wires_permutation_partial_products(&witness, &betas, &gammas, prover_data, common_data) all_wires_permutation_partial_products(&witness, &betas, &gammas, prover_data, common_data)
); );
let plonk_z_vecs = timed!( // Z is expected at the front of our batch; see `zs_range` and `partial_products_range`.
timing, let plonk_z_vecs = partial_products_and_zs
"compute Z's", .iter_mut()
compute_zs(&partial_products, common_data) .map(|partial_products_and_z| partial_products_and_z.pop().unwrap())
); .collect();
let zs_partial_products = [plonk_z_vecs, partial_products_and_zs.concat()].concat();
// The first polynomial in `partial_products` represent the final product used in the let partial_products_and_zs_commitment = timed!(
// computation of `Z`. It isn't needed anymore so we discard it.
partial_products.iter_mut().for_each(|part| {
part.remove(0);
});
let zs_partial_products = [plonk_z_vecs, partial_products.concat()].concat();
let zs_partial_products_commitment = timed!(
timing, timing,
"commit to Z's", "commit to partial products and Z's",
PolynomialBatchCommitment::from_values( PolynomialBatchCommitment::from_values(
zs_partial_products, zs_partial_products,
config.rate_bits, config.rate_bits,
@ -121,7 +117,7 @@ pub(crate) fn prove<F: Extendable<D>, C: GenericConfig<D, F = F>, const D: usize
) )
); );
challenger.observe_cap(&zs_partial_products_commitment.merkle_tree.cap); challenger.observe_cap(&partial_products_and_zs_commitment.merkle_tree.cap);
let alphas = challenger.get_n_challenges(num_challenges); let alphas = challenger.get_n_challenges(num_challenges);
@ -133,7 +129,7 @@ pub(crate) fn prove<F: Extendable<D>, C: GenericConfig<D, F = F>, const D: usize
prover_data, prover_data,
&public_inputs_hash, &public_inputs_hash,
&wires_commitment, &wires_commitment,
&zs_partial_products_commitment, &partial_products_and_zs_commitment,
&betas, &betas,
&gammas, &gammas,
&alphas, &alphas,
@ -148,7 +144,6 @@ pub(crate) fn prove<F: Extendable<D>, C: GenericConfig<D, F = F>, const D: usize
.into_par_iter() .into_par_iter()
.flat_map(|mut quotient_poly| { .flat_map(|mut quotient_poly| {
quotient_poly.trim(); quotient_poly.trim();
// TODO: Return Result instead of panicking.
quotient_poly.pad(quotient_degree).expect( quotient_poly.pad(quotient_degree).expect(
"Quotient has failed, the vanishing polynomial is not divisible by `Z_H", "Quotient has failed, the vanishing polynomial is not divisible by `Z_H",
); );
@ -182,7 +177,7 @@ pub(crate) fn prove<F: Extendable<D>, C: GenericConfig<D, F = F>, const D: usize
&[ &[
&prover_data.constants_sigmas_commitment, &prover_data.constants_sigmas_commitment,
&wires_commitment, &wires_commitment,
&zs_partial_products_commitment, &partial_products_and_zs_commitment,
&quotient_polys_commitment, &quotient_polys_commitment,
], ],
zeta, zeta,
@ -194,7 +189,7 @@ pub(crate) fn prove<F: Extendable<D>, C: GenericConfig<D, F = F>, const D: usize
let proof = Proof { let proof = Proof {
wires_cap: wires_commitment.merkle_tree.cap, wires_cap: wires_commitment.merkle_tree.cap,
plonk_zs_partial_products_cap: zs_partial_products_commitment.merkle_tree.cap, plonk_zs_partial_products_cap: partial_products_and_zs_commitment.merkle_tree.cap,
quotient_polys_cap: quotient_polys_commitment.merkle_tree.cap, quotient_polys_cap: quotient_polys_commitment.merkle_tree.cap,
openings, openings,
opening_proof, opening_proof,
@ -219,7 +214,7 @@ fn all_wires_permutation_partial_products<
) -> Vec<Vec<PolynomialValues<F>>> { ) -> Vec<Vec<PolynomialValues<F>>> {
(0..common_data.config.num_challenges) (0..common_data.config.num_challenges)
.map(|i| { .map(|i| {
wires_permutation_partial_products( wires_permutation_partial_products_and_zs(
witness, witness,
betas[i], betas[i],
gammas[i], gammas[i],
@ -233,7 +228,7 @@ fn all_wires_permutation_partial_products<
/// Compute the partial products used in the `Z` polynomial. /// Compute the partial products used in the `Z` polynomials, along with `Z` itself.
/// Returns the polynomials interpolating `partial_products(f / g)` /// Returns the polynomials interpolating `partial_products(f / g)` followed by `Z`,
/// where `f, g` are the products in the definition of `Z`: `Z(g^i) = f / g`. /// where `f, g` are the products in the definition of `Z`: `Z(g^i) = f / g`.
fn wires_permutation_partial_products< fn wires_permutation_partial_products_and_zs<
F: Extendable<D>, F: Extendable<D>,
C: GenericConfig<D, F = F>, C: GenericConfig<D, F = F>,
const D: usize, const D: usize,
@ -247,7 +242,8 @@ fn wires_permutation_partial_products<
let degree = common_data.quotient_degree_factor; let degree = common_data.quotient_degree_factor;
let subgroup = &prover_data.subgroup; let subgroup = &prover_data.subgroup;
let k_is = &common_data.k_is; let k_is = &common_data.k_is;
let values = subgroup let (num_prods, _final_num_prod) = common_data.num_partial_products;
let all_quotient_chunk_products = subgroup
.par_iter() .par_iter()
.enumerate() .enumerate()
.map(|(i, &x)| { .map(|(i, &x)| {
@ -271,49 +267,26 @@ fn wires_permutation_partial_products<
.map(|(num, den_inv)| num * den_inv) .map(|(num, den_inv)| num * den_inv)
.collect::<Vec<_>>(); .collect::<Vec<_>>();
let quotient_partials = partial_products(&quotient_values, degree); quotient_chunk_products(&quotient_values, degree)
// This is the final product for the quotient.
let quotient = quotient_partials
[common_data.num_partial_products.0 - common_data.num_partial_products.1..]
.iter()
.copied()
.product();
// We add the quotient at the beginning of the vector to reuse them later in the computation of `Z`.
[vec![quotient], quotient_partials].concat()
}) })
.collect::<Vec<_>>(); .collect::<Vec<_>>();
transpose(&values) let mut z_x = F::ONE;
let mut all_partial_products_and_zs = Vec::new();
for quotient_chunk_products in all_quotient_chunk_products {
let mut partial_products_and_z_gx =
partial_products_and_z_gx(z_x, &quotient_chunk_products);
// The last term is Z(gx), but we replace it with Z(x), otherwise Z would end up shifted.
swap(&mut z_x, &mut partial_products_and_z_gx[num_prods]);
all_partial_products_and_zs.push(partial_products_and_z_gx);
}
transpose(&all_partial_products_and_zs)
.into_par_iter() .into_par_iter()
.map(PolynomialValues::new) .map(PolynomialValues::new)
.collect() .collect()
} }
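A hedged `u64` sketch of the shapes of `quotient_chunk_products` and `partial_products_and_z_gx` as used above; the real helpers in `util::partial_products` operate on field elements, so this mirrors only their structure:

fn quotient_chunk_products(quotients: &[u64], degree: usize) -> Vec<u64> {
    // Products of the per-wire quotient terms, `degree` wires at a time.
    quotients.chunks(degree).map(|c| c.iter().product()).collect()
}

fn partial_products_and_z_gx(z_x: u64, chunk_products: &[u64]) -> Vec<u64> {
    // Running products seeded with Z(x); the final entry is Z(gx).
    chunk_products
        .iter()
        .scan(z_x, |acc, &p| {
            *acc *= p;
            Some(*acc)
        })
        .collect()
}

fn main() {
    // Six per-wire quotients in chunks of 2 give chunk products [2, 3, 5].
    let chunks = quotient_chunk_products(&[1, 2, 1, 3, 5, 1], 2);
    // Seeded with Z(x) = 1, the running products are [2, 6, 30]; the last
    // entry is Z(gx), which the `swap` above trades for the incoming Z(x).
    assert_eq!(partial_products_and_z_gx(1, &chunks), vec![2, 6, 30]);
}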
fn compute_zs<F: Extendable<D>, C: GenericConfig<D, F = F>, const D: usize>(
partial_products: &[Vec<PolynomialValues<F>>],
common_data: &CommonCircuitData<F, C, D>,
) -> Vec<PolynomialValues<F>> {
(0..common_data.config.num_challenges)
.map(|i| compute_z(&partial_products[i], common_data))
.collect()
}
/// Compute the `Z` polynomial by reusing the computations done in `wires_permutation_partial_products`.
fn compute_z<F: Extendable<D>, C: GenericConfig<D, F = F>, const D: usize>(
partial_products: &[PolynomialValues<F>],
common_data: &CommonCircuitData<F, C, D>,
) -> PolynomialValues<F> {
let mut plonk_z_points = vec![F::ONE];
for i in 1..common_data.degree() {
let quotient = partial_products[0].values[i - 1];
let last = *plonk_z_points.last().unwrap();
plonk_z_points.push(last * quotient);
}
plonk_z_points.into()
}
const BATCH_SIZE: usize = 32; const BATCH_SIZE: usize = 32;
fn compute_quotient_polys<'a, F: Extendable<D>, C: GenericConfig<D, F = F>, const D: usize>( fn compute_quotient_polys<'a, F: Extendable<D>, C: GenericConfig<D, F = F>, const D: usize>(

View File

@ -133,6 +133,7 @@ mod tests {
use crate::fri::reduction_strategies::FriReductionStrategy; use crate::fri::reduction_strategies::FriReductionStrategy;
use crate::fri::FriConfig; use crate::fri::FriConfig;
use crate::gadgets::polynomial::PolynomialCoeffsExtTarget; use crate::gadgets::polynomial::PolynomialCoeffsExtTarget;
use crate::gates::noop::NoopGate;
use crate::hash::merkle_proofs::MerkleProofTarget; use crate::hash::merkle_proofs::MerkleProofTarget;
use crate::iop::witness::{PartialWitness, Witness}; use crate::iop::witness::{PartialWitness, Witness};
use crate::plonk::circuit_data::VerifierOnlyCircuitData; use crate::plonk::circuit_data::VerifierOnlyCircuitData;
@ -369,9 +370,8 @@ mod tests {
type F = <C as GenericConfig<D>>::F; type F = <C as GenericConfig<D>>::F;
let config = CircuitConfig::standard_recursion_config(); let config = CircuitConfig::standard_recursion_config();
let (proof, vd, cd) = dummy_proof::<F, C, D>(&config, 8_000)?; let (proof, vd, cd) = dummy_proof::<F, C, D>(&config, 4_000)?;
let (proof, _vd, cd) = let (proof, _vd, cd) = recursive_proof::<F, C, C, D>(proof, vd, cd, &config, &config, None, true, true)?;
recursive_proof::<F, C, C, D>(proof, vd, cd, &config, &config, true, true)?;
test_serialization(&proof, &cd)?; test_serialization(&proof, &cd)?;
Ok(()) Ok(())
@ -388,11 +388,14 @@ mod tests {
let config = CircuitConfig::standard_recursion_config(); let config = CircuitConfig::standard_recursion_config();
let (proof, vd, cd) = dummy_proof::<F, C, D>(&config, 8_000)?; // Start with a degree 2^14 proof, then shrink it to 2^13, then to 2^12.
let (proof, vd, cd) = dummy_proof::<F, C, D>(&config, 16_000)?;
assert_eq!(cd.degree_bits, 14);
let (proof, vd, cd) = let (proof, vd, cd) =
recursive_proof::<F, C, C, D>(proof, vd, cd, &config, &config, false, false)?; recursive_proof::<F, C, C, D>(proof, vd, cd, &config, &config, Some(13), false, false)?;
let (proof, _vd, cd) = assert_eq!(cd.degree_bits, 13);
recursive_proof::<F, KC, C, D>(proof, vd, cd, &config, &config, true, true)?; let (proof, _vd, cd) = recursive_proof::<F, KC, C, D>(proof, vd, cd, &config, &config, None, true, true)?;
assert_eq!(cd.degree_bits, 12);
test_serialization(&proof, &cd)?; test_serialization(&proof, &cd)?;
@ -412,29 +415,29 @@ mod tests {
let standard_config = CircuitConfig::standard_recursion_config(); let standard_config = CircuitConfig::standard_recursion_config();
// A dummy proof with degree 2^13. // An initial dummy proof.
let (proof, vd, cd) = dummy_proof::<F, C, D>(&standard_config, 8_000)?; let (proof, vd, cd) = dummy_proof::<F, C, D>(&standard_config, 4_000)?;
assert_eq!(cd.degree_bits, 13); assert_eq!(cd.degree_bits, 12);
// A standard recursive proof with degree 2^13. // A standard recursive proof.
let (proof, vd, cd) = recursive_proof( let (proof, vd, cd) = recursive_proof(
proof, proof,
vd, vd,
cd, cd,
&standard_config, &standard_config,
&standard_config, &standard_config,
None,
false, false,
false, false,
)?; )?;
assert_eq!(cd.degree_bits, 13); assert_eq!(cd.degree_bits, 12);
// A high-rate recursive proof with degree 2^13, designed to be verifiable with 2^12 // A high-rate recursive proof, designed to be verifiable with fewer routed wires.
// gates and 48 routed wires.
let high_rate_config = CircuitConfig { let high_rate_config = CircuitConfig {
rate_bits: 5, rate_bits: 7,
fri_config: FriConfig { fri_config: FriConfig {
proof_of_work_bits: 20, proof_of_work_bits: 16,
num_query_rounds: 16, num_query_rounds: 12,
..standard_config.fri_config.clone() ..standard_config.fri_config.clone()
}, },
..standard_config ..standard_config
@ -445,54 +448,35 @@ mod tests {
cd, cd,
&standard_config, &standard_config,
&high_rate_config, &high_rate_config,
true, None,
true,
)?;
assert_eq!(cd.degree_bits, 13);
// A higher-rate recursive proof with degree 2^12, designed to be verifiable with 2^12
// gates and 28 routed wires.
let higher_rate_more_routing_config = CircuitConfig {
rate_bits: 7,
num_routed_wires: 48,
fri_config: FriConfig {
proof_of_work_bits: 23,
num_query_rounds: 11,
..standard_config.fri_config.clone()
},
..high_rate_config.clone()
};
let (proof, vd, cd) = recursive_proof::<F, C, C, D>(
proof,
vd,
cd,
&high_rate_config,
&higher_rate_more_routing_config,
true, true,
true, true,
)?; )?;
assert_eq!(cd.degree_bits, 12); assert_eq!(cd.degree_bits, 12);
// A final proof of degree 2^12, optimized for size. // A final proof, optimized for size.
let final_config = CircuitConfig { let final_config = CircuitConfig {
cap_height: 0, cap_height: 0,
num_routed_wires: 32, rate_bits: 8,
num_routed_wires: 37,
fri_config: FriConfig { fri_config: FriConfig {
proof_of_work_bits: 20,
reduction_strategy: FriReductionStrategy::MinSize(None), reduction_strategy: FriReductionStrategy::MinSize(None),
..higher_rate_more_routing_config.fri_config.clone() num_query_rounds: 10,
}, },
..higher_rate_more_routing_config ..high_rate_config
}; };
let (proof, _vd, cd) = recursive_proof::<F, KC, C, D>( let (proof, _vd, cd) = recursive_proof::<F, KC, C, D>(
proof, proof,
vd, vd,
cd, cd,
&higher_rate_more_routing_config, &high_rate_config,
&final_config, &final_config,
None,
true, true,
true, true,
)?; )?;
assert_eq!(cd.degree_bits, 12); assert_eq!(cd.degree_bits, 12, "final proof too large");
test_serialization(&proof, &cd)?; test_serialization(&proof, &cd)?;
@ -509,16 +493,12 @@ mod tests {
CommonCircuitData<F, C, D>, CommonCircuitData<F, C, D>,
)> { )> {
let mut builder = CircuitBuilder::<F, D>::new(config.clone()); let mut builder = CircuitBuilder::<F, D>::new(config.clone());
let input = builder.add_virtual_target(); for _ in 0..num_dummy_gates {
for i in 0..num_dummy_gates { builder.add_gate(NoopGate, vec![]);
// Use unique constants to force a new `ArithmeticGate`.
let i_f = F::from_canonical_u64(i);
builder.arithmetic(i_f, i_f, input, input, input);
} }
let data = builder.build::<C>(); let data = builder.build::<C>();
let mut inputs = PartialWitness::new(); let inputs = PartialWitness::new();
inputs.set_target(input, F::ZERO);
let proof = data.prove(inputs)?; let proof = data.prove(inputs)?;
data.verify(proof.clone())?; data.verify(proof.clone())?;
@ -536,6 +516,7 @@ mod tests {
inner_cd: CommonCircuitData<F, InnerC, D>, inner_cd: CommonCircuitData<F, InnerC, D>,
inner_config: &CircuitConfig, inner_config: &CircuitConfig,
config: &CircuitConfig, config: &CircuitConfig,
min_degree_bits: Option<usize>,
print_gate_counts: bool, print_gate_counts: bool,
print_timing: bool, print_timing: bool,
) -> Result<( ) -> Result<(
@ -556,12 +537,22 @@ mod tests {
&inner_vd.constants_sigmas_cap, &inner_vd.constants_sigmas_cap,
); );
builder.add_recursive_verifier(pt, &inner_config, &inner_data, &inner_cd); builder.add_recursive_verifier(pt, inner_config, &inner_data, &inner_cd);
if print_gate_counts { if print_gate_counts {
builder.print_gate_counts(0); builder.print_gate_counts(0);
} }
if let Some(min_degree_bits) = min_degree_bits {
// We don't want to pad all the way up to 2^min_degree_bits, as the builder will add a
// few special gates afterward. So just pad to 2^(min_degree_bits - 1) + 1. Then the
// builder will pad to the next power of two, 2^min_degree_bits.
let min_gates = (1 << (min_degree_bits - 1)) + 1;
for _ in builder.num_gates()..min_gates {
builder.add_gate(NoopGate, vec![]);
}
}
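        // For example, min_degree_bits = 13 gives min_gates = 2^12 + 1 = 4097; once the
        // builder appends its special gates, the count is rounded up to 2^13 = 8192.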
let data = builder.build::<C>(); let data = builder.build::<C>();
let mut timing = TimingTree::new("prove", Level::Debug); let mut timing = TimingTree::new("prove", Level::Debug);
@ -582,12 +573,12 @@ mod tests {
) -> Result<()> { ) -> Result<()> {
let proof_bytes = proof.to_bytes()?; let proof_bytes = proof.to_bytes()?;
info!("Proof length: {} bytes", proof_bytes.len()); info!("Proof length: {} bytes", proof_bytes.len());
let proof_from_bytes = ProofWithPublicInputs::from_bytes(proof_bytes, &cd)?; let proof_from_bytes = ProofWithPublicInputs::from_bytes(proof_bytes, cd)?;
assert_eq!(proof, &proof_from_bytes); assert_eq!(proof, &proof_from_bytes);
let now = std::time::Instant::now(); let now = std::time::Instant::now();
let compressed_proof = proof.clone().compress(&cd)?; let compressed_proof = proof.clone().compress(cd)?;
let decompressed_compressed_proof = compressed_proof.clone().decompress(&cd)?; let decompressed_compressed_proof = compressed_proof.clone().decompress(cd)?;
info!("{:.4}s to compress proof", now.elapsed().as_secs_f64()); info!("{:.4}s to compress proof", now.elapsed().as_secs_f64());
assert_eq!(proof, &decompressed_compressed_proof); assert_eq!(proof, &decompressed_compressed_proof);
@ -597,7 +588,7 @@ mod tests {
compressed_proof_bytes.len() compressed_proof_bytes.len()
); );
let compressed_proof_from_bytes = let compressed_proof_from_bytes =
CompressedProofWithPublicInputs::from_bytes(compressed_proof_bytes, &cd)?; CompressedProofWithPublicInputs::from_bytes(compressed_proof_bytes, cd)?;
assert_eq!(compressed_proof, compressed_proof_from_bytes); assert_eq!(compressed_proof, compressed_proof_from_bytes);
Ok(()) Ok(())

View File

@ -29,7 +29,7 @@ pub(crate) fn eval_vanishing_poly<F: Extendable<D>, C: GenericConfig<D, F = F>,
alphas: &[F], alphas: &[F],
) -> Vec<F::Extension> { ) -> Vec<F::Extension> {
let max_degree = common_data.quotient_degree_factor; let max_degree = common_data.quotient_degree_factor;
let (num_prods, final_num_prod) = common_data.num_partial_products; let (num_prods, _final_num_prod) = common_data.num_partial_products;
let constraint_terms = let constraint_terms =
evaluate_gate_constraints(&common_data.gates, common_data.num_gate_constraints, vars); evaluate_gate_constraints(&common_data.gates, common_data.num_gate_constraints, vars);
@ -38,14 +38,12 @@ pub(crate) fn eval_vanishing_poly<F: Extendable<D>, C: GenericConfig<D, F = F>,
let mut vanishing_z_1_terms = Vec::new(); let mut vanishing_z_1_terms = Vec::new();
// The terms checking the partial products. // The terms checking the partial products.
let mut vanishing_partial_products_terms = Vec::new(); let mut vanishing_partial_products_terms = Vec::new();
// The Z(x) f'(x) - g'(x) Z(g x) terms.
let mut vanishing_v_shift_terms = Vec::new();
let l1_x = plonk_common::eval_l_1(common_data.degree(), x); let l1_x = plonk_common::eval_l_1(common_data.degree(), x);
for i in 0..common_data.config.num_challenges { for i in 0..common_data.config.num_challenges {
let z_x = local_zs[i]; let z_x = local_zs[i];
let z_gz = next_zs[i]; let z_gx = next_zs[i];
vanishing_z_1_terms.push(l1_x * (z_x - F::Extension::ONE)); vanishing_z_1_terms.push(l1_x * (z_x - F::Extension::ONE));
let numerator_values = (0..common_data.config.num_routed_wires) let numerator_values = (0..common_data.config.num_routed_wires)
@@ -63,37 +61,24 @@ pub(crate) fn eval_vanishing_poly<F: Extendable<D>, C: GenericConfig<D, F = F>,
                wire_value + s_sigma.scalar_mul(betas[i]) + gammas[i].into()
            })
            .collect::<Vec<_>>();
-       let quotient_values = (0..common_data.config.num_routed_wires)
-           .map(|j| numerator_values[j] / denominator_values[j])
-           .collect::<Vec<_>>();
        // The partial products considered for this iteration of `i`.
        let current_partial_products = &partial_products[i * num_prods..(i + 1) * num_prods];
        // Check the quotient partial products.
-       let mut partial_product_check =
-           check_partial_products(&quotient_values, current_partial_products, max_degree);
-       // The first checks are of the form `q - n/d`, which is a rational function, not a polynomial.
-       // We multiply them by `d` to get checks of the form `q*d - n`, which are low-degree polynomials.
-       denominator_values
-           .chunks(max_degree)
-           .zip(partial_product_check.iter_mut())
-           .for_each(|(d, q)| {
-               *q *= d.iter().copied().product();
-           });
-       vanishing_partial_products_terms.extend(partial_product_check);
-       // The quotient final product is the product of the last `final_num_prod` elements.
-       let quotient: F::Extension = current_partial_products[num_prods - final_num_prod..]
-           .iter()
-           .copied()
-           .product();
-       vanishing_v_shift_terms.push(quotient * z_x - z_gz);
+       let partial_product_checks = check_partial_products(
+           &numerator_values,
+           &denominator_values,
+           current_partial_products,
+           z_x,
+           z_gx,
+           max_degree,
+       );
+       vanishing_partial_products_terms.extend(partial_product_checks);
    }
    let vanishing_terms = [
        vanishing_z_1_terms,
        vanishing_partial_products_terms,
-       vanishing_v_shift_terms,
        constraint_terms,
    ]
    .concat();
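
For context on the refactor above: the explicit `q - n/d` quotient checks and the separate `Z(x)`-shift terms are folded into a single `check_partial_products` call. Below is a hedged sketch, not the crate's implementation, of the identity that call appears to enforce; the signature and the accumulator chaining (`z_x`, then the partial products, then `z_gx`) are inferred from the call sites in this diff.

// Sketch only. Plain integers stand in for field elements; every returned
// term must be zero for the check to pass.
fn check_partial_products_sketch(
    numerators: &[i128],
    denominators: &[i128],
    partials: &[i128],
    z_x: i128,
    z_gx: i128,
    max_degree: usize,
) -> Vec<i128> {
    // Chain the running-product accumulators: Z(x), the committed partial
    // products, and finally Z(g x).
    let accs: Vec<i128> = std::iter::once(z_x)
        .chain(partials.iter().copied())
        .chain(std::iter::once(z_gx))
        .collect();
    numerators
        .chunks(max_degree)
        .zip(denominators.chunks(max_degree))
        .zip(accs.windows(2))
        // Each step folds in at most `max_degree` wire terms, so the check
        // `next * prod(d) - prev * prod(n)` stays low-degree; no division is needed.
        .map(|((n, d), w)| {
            w[1] * d.iter().product::<i128>() - w[0] * n.iter().product::<i128>()
        })
        .collect()
}

fn main() {
    // Toy instance: one step with numerator product 2 * 3 = 6 and denominator
    // product 1, so the shifted accumulator must be 6 times the previous one.
    let checks = check_partial_products_sketch(&[2, 3], &[1, 1], &[], 1, 6, 2);
    assert!(checks.iter().all(|&c| c == 0));
}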
@@ -130,7 +115,7 @@ pub(crate) fn eval_vanishing_poly_base_batch<
    assert_eq!(s_sigmas_batch.len(), n);
    let max_degree = common_data.quotient_degree_factor;
-   let (num_prods, final_num_prod) = common_data.num_partial_products;
+   let (num_prods, _final_num_prod) = common_data.num_partial_products;
    let num_gate_constraints = common_data.num_gate_constraints;
@@ -143,14 +128,11 @@ pub(crate) fn eval_vanishing_poly_base_batch<
    let mut numerator_values = Vec::with_capacity(num_routed_wires);
    let mut denominator_values = Vec::with_capacity(num_routed_wires);
-   let mut quotient_values = Vec::with_capacity(num_routed_wires);
    // The L_1(x) (Z(x) - 1) vanishing terms.
    let mut vanishing_z_1_terms = Vec::with_capacity(num_challenges);
    // The terms checking the partial products.
    let mut vanishing_partial_products_terms = Vec::new();
-   // The Z(x) f'(x) - g'(x) Z(g x) terms.
-   let mut vanishing_v_shift_terms = Vec::with_capacity(num_challenges);
    let mut res_batch: Vec<Vec<F>> = Vec::with_capacity(n);
    for k in 0..n {
@@ -168,7 +150,7 @@ pub(crate) fn eval_vanishing_poly_base_batch<
        let l1_x = z_h_on_coset.eval_l1(index, x);
        for i in 0..num_challenges {
            let z_x = local_zs[i];
-           let z_gz = next_zs[i];
+           let z_gx = next_zs[i];
            vanishing_z_1_terms.push(l1_x * z_x.sub_one());
            numerator_values.extend((0..num_routed_wires).map(|j| {
@@ -182,49 +164,33 @@ pub(crate) fn eval_vanishing_poly_base_batch<
                let s_sigma = s_sigmas[j];
                wire_value + betas[i] * s_sigma + gammas[i]
            }));
-           let denominator_inverses = F::batch_multiplicative_inverse(&denominator_values);
-           quotient_values.extend(
-               (0..num_routed_wires).map(|j| numerator_values[j] * denominator_inverses[j]),
-           );
            // The partial products considered for this iteration of `i`.
            let current_partial_products = &partial_products[i * num_prods..(i + 1) * num_prods];
            // Check the numerator partial products.
-           let mut partial_product_check =
-               check_partial_products(&quotient_values, current_partial_products, max_degree);
-           // The first checks are of the form `q - n/d`, which is a rational function, not a polynomial.
-           // We multiply them by `d` to get checks of the form `q*d - n`, which are low-degree polynomials.
-           denominator_values
-               .chunks(max_degree)
-               .zip(partial_product_check.iter_mut())
-               .for_each(|(d, q)| {
-                   *q *= d.iter().copied().product();
-               });
-           vanishing_partial_products_terms.extend(partial_product_check);
-           // The quotient final product is the product of the last `final_num_prod` elements.
-           let quotient: F = current_partial_products[num_prods - final_num_prod..]
-               .iter()
-               .copied()
-               .product();
-           vanishing_v_shift_terms.push(quotient * z_x - z_gz);
+           let partial_product_checks = check_partial_products(
+               &numerator_values,
+               &denominator_values,
+               current_partial_products,
+               z_x,
+               z_gx,
+               max_degree,
+           );
+           vanishing_partial_products_terms.extend(partial_product_checks);
            numerator_values.clear();
            denominator_values.clear();
-           quotient_values.clear();
        }
        let vanishing_terms = vanishing_z_1_terms
            .iter()
            .chain(vanishing_partial_products_terms.iter())
-           .chain(vanishing_v_shift_terms.iter())
            .chain(constraint_terms);
        let res = plonk_common::reduce_with_powers_multi(vanishing_terms, alphas);
        res_batch.push(res);
        vanishing_z_1_terms.clear();
        vanishing_partial_products_terms.clear();
-       vanishing_v_shift_terms.clear();
    }
    res_batch
}
@@ -334,7 +300,7 @@ pub(crate) fn eval_vanishing_poly_recursively<
    alphas: &[Target],
) -> Vec<ExtensionTarget<D>> {
    let max_degree = common_data.quotient_degree_factor;
-   let (num_prods, final_num_prod) = common_data.num_partial_products;
+   let (num_prods, _final_num_prod) = common_data.num_partial_products;
    let constraint_terms = with_context!(
        builder,
@@ -351,8 +317,6 @@ pub(crate) fn eval_vanishing_poly_recursively<
    let mut vanishing_z_1_terms = Vec::new();
    // The terms checking the partial products.
    let mut vanishing_partial_products_terms = Vec::new();
-   // The Z(x) f'(x) - g'(x) Z(g x) terms.
-   let mut vanishing_v_shift_terms = Vec::new();
    let l1_x = eval_l_1_recursively(builder, common_data.degree(), x, x_pow_deg);
@@ -365,14 +329,13 @@ pub(crate) fn eval_vanishing_poly_recursively<
    for i in 0..common_data.config.num_challenges {
        let z_x = local_zs[i];
-       let z_gz = next_zs[i];
+       let z_gx = next_zs[i];
        // L_1(x) (Z(x) - 1) = 0.
        vanishing_z_1_terms.push(builder.mul_sub_extension(l1_x, z_x, l1_x));
        let mut numerator_values = Vec::new();
        let mut denominator_values = Vec::new();
-       let mut quotient_values = Vec::new();
        for j in 0..common_data.config.num_routed_wires {
            let wire_value = vars.local_wires[j];
@@ -385,44 +348,28 @@ pub(crate) fn eval_vanishing_poly_recursively<
            let numerator = builder.mul_add_extension(beta_ext, s_ids[j], wire_value_plus_gamma);
            let denominator =
                builder.mul_add_extension(beta_ext, s_sigmas[j], wire_value_plus_gamma);
-           let quotient = builder.div_extension(numerator, denominator);
            numerator_values.push(numerator);
            denominator_values.push(denominator);
-           quotient_values.push(quotient);
        }
        // The partial products considered for this iteration of `i`.
        let current_partial_products = &partial_products[i * num_prods..(i + 1) * num_prods];
        // Check the quotient partial products.
-       let mut partial_product_check = check_partial_products_recursively(
-           builder,
-           &quotient_values,
-           current_partial_products,
-           max_degree,
-       );
-       // The first checks are of the form `q - n/d`, which is a rational function, not a polynomial.
-       // We multiply them by `d` to get checks of the form `q*d - n`, which are low-degree polynomials.
-       denominator_values
-           .chunks(max_degree)
-           .zip(partial_product_check.iter_mut())
-           .for_each(|(d, q)| {
-               let mut v = d.to_vec();
-               v.push(*q);
-               *q = builder.mul_many_extension(&v);
-           });
-       vanishing_partial_products_terms.extend(partial_product_check);
-       // The quotient final product is the product of the last `final_num_prod` elements.
-       let quotient =
-           builder.mul_many_extension(&current_partial_products[num_prods - final_num_prod..]);
-       vanishing_v_shift_terms.push(builder.mul_sub_extension(quotient, z_x, z_gz));
+       let partial_product_checks = check_partial_products_recursively(
+           builder,
+           &numerator_values,
+           &denominator_values,
+           current_partial_products,
+           z_x,
+           z_gx,
+           max_degree,
+       );
+       vanishing_partial_products_terms.extend(partial_product_checks);
    }
    let vanishing_terms = [
        vanishing_z_1_terms,
        vanishing_partial_products_terms,
-       vanishing_v_shift_terms,
        constraint_terms,
    ]
    .concat();


@@ -1,4 +1,3 @@
-use std::convert::TryInto;
use std::ops::Range;

use crate::field::extension_field::algebra::ExtensionAlgebra;


@@ -1,7 +1,6 @@
-use crate::field::fft::{fft, ifft};
use crate::field::field_types::Field;
-use crate::polynomial::polynomial::PolynomialCoeffs;
+use crate::polynomial::PolynomialCoeffs;
-use crate::util::{log2_ceil, log2_strict};
+use crate::util::log2_ceil;

impl<F: Field> PolynomialCoeffs<F> {
    /// Polynomial division.
@@ -67,63 +66,6 @@ impl<F: Field> PolynomialCoeffs<F> {
        }
    }
/// Takes a polynomial `a` in coefficient form, and divides it by `Z_H = X^n - 1`.
///
/// This assumes `Z_H | a`, otherwise result is meaningless.
pub(crate) fn divide_by_z_h(&self, n: usize) -> PolynomialCoeffs<F> {
let mut a = self.clone();
// TODO: Is this special case needed?
if a.coeffs.iter().all(|p| *p == F::ZERO) {
return a;
}
let g = F::MULTIPLICATIVE_GROUP_GENERATOR;
let mut g_pow = F::ONE;
// Multiply the i-th coefficient of `a` by `g^i`. Then `new_a(w^j) = old_a(g.w^j)`.
a.coeffs.iter_mut().for_each(|x| {
*x *= g_pow;
g_pow *= g;
});
let root = F::primitive_root_of_unity(log2_strict(a.len()));
// Equals to the evaluation of `a` on `{g.w^i}`.
let mut a_eval = fft(&a);
// Compute the denominators `1/(g^n.w^(n*i) - 1)` using batch inversion.
let denominator_g = g.exp_u64(n as u64);
let root_n = root.exp_u64(n as u64);
let mut root_pow = F::ONE;
let denominators = (0..a_eval.len())
.map(|i| {
if i != 0 {
root_pow *= root_n;
}
denominator_g * root_pow - F::ONE
})
.collect::<Vec<_>>();
let denominators_inv = F::batch_multiplicative_inverse(&denominators);
// Divide every element of `a_eval` by the corresponding denominator.
// Then, `a_eval` is the evaluation of `a/Z_H` on `{g.w^i}`.
a_eval
.values
.iter_mut()
.zip(denominators_inv.iter())
.for_each(|(x, &d)| {
*x *= d;
});
// `p` is the interpolating polynomial of `a_eval` on `{w^i}`.
let mut p = ifft(&a_eval);
// We need to scale it by `g^(-i)` to get the interpolating polynomial of `a_eval` on `{g.w^i}`,
// a.k.a `a/Z_H`.
let g_inv = g.inverse();
let mut g_inv_pow = F::ONE;
p.coeffs.iter_mut().for_each(|x| {
*x *= g_inv_pow;
g_inv_pow *= g_inv;
});
p
}
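    // Why the coset trick above works: Z_H(x) = x^n - 1 vanishes at every point
    // of H itself, so the division is performed on the shifted coset g*H, where
    // Z_H(g * w^i) = g^n * w^(n*i) - 1 is nonzero and can be batch-inverted.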
    /// Let `self = p(X)`; this returns `(p(X) - p(z))/(X - z)` and `p(z)`.
    /// See https://en.wikipedia.org/wiki/Horner%27s_method
    pub(crate) fn divide_by_linear(&self, z: F) -> (PolynomialCoeffs<F>, F) {
@@ -187,35 +129,7 @@ mod tests {
    use crate::field::extension_field::quartic::QuarticExtension;
    use crate::field::field_types::Field;
    use crate::field::goldilocks_field::GoldilocksField;
-   use crate::polynomial::polynomial::PolynomialCoeffs;
+   use crate::polynomial::PolynomialCoeffs;
#[test]
fn zero_div_z_h() {
type F = GoldilocksField;
let zero = PolynomialCoeffs::<F>::zero(16);
let quotient = zero.divide_by_z_h(4);
assert_eq!(quotient, zero);
}
#[test]
fn division_by_z_h() {
type F = GoldilocksField;
let zero = F::ZERO;
let three = F::from_canonical_u64(3);
let four = F::from_canonical_u64(4);
let five = F::from_canonical_u64(5);
let six = F::from_canonical_u64(6);
// a(x) = Z_4(x) q(x), where
// a(x) = 3 x^7 + 4 x^6 + 5 x^5 + 6 x^4 - 3 x^3 - 4 x^2 - 5 x - 6
// Z_4(x) = x^4 - 1
// q(x) = 3 x^3 + 4 x^2 + 5 x + 6
let a = PolynomialCoeffs::new(vec![-six, -five, -four, -three, six, five, four, three]);
let q = PolynomialCoeffs::new(vec![six, five, four, three, zero, zero, zero, zero]);
let computed_q = a.divide_by_z_h(4);
assert_eq!(computed_q, q);
}
    #[test]
    #[ignore]


@@ -1,2 +1,616 @@
pub(crate) mod division;
-pub mod polynomial;
use std::cmp::max;
use std::iter::Sum;
use std::ops::{Add, AddAssign, Mul, MulAssign, Sub, SubAssign};
use anyhow::{ensure, Result};
use serde::{Deserialize, Serialize};
use crate::field::extension_field::{Extendable, FieldExtension};
use crate::field::fft::{fft, fft_with_options, ifft, FftRootTable};
use crate::field::field_types::Field;
use crate::util::log2_strict;
/// A polynomial in point-value form.
///
/// The points are implicitly `g^i`, where `g` generates the subgroup whose size equals the number
/// of points.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct PolynomialValues<F: Field> {
pub values: Vec<F>,
}
impl<F: Field> PolynomialValues<F> {
pub fn new(values: Vec<F>) -> Self {
PolynomialValues { values }
}
/// The number of values stored.
pub(crate) fn len(&self) -> usize {
self.values.len()
}
pub fn ifft(&self) -> PolynomialCoeffs<F> {
ifft(self)
}
/// Returns the polynomial whose evaluation on the coset `shift*H` is `self`.
pub fn coset_ifft(&self, shift: F) -> PolynomialCoeffs<F> {
let mut shifted_coeffs = self.ifft();
shifted_coeffs
.coeffs
.iter_mut()
.zip(shift.inverse().powers())
.for_each(|(c, r)| {
*c *= r;
});
shifted_coeffs
}
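    // Note: a plain ifft of the values p(shift * g^i) yields q with
    // q(x) = p(shift * x), i.e. coefficients q_i = p_i * shift^i. Multiplying
    // coefficient i by shift^(-i) above undoes that scaling, recovering p itself.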
pub fn lde_multiple(polys: Vec<Self>, rate_bits: usize) -> Vec<Self> {
polys.into_iter().map(|p| p.lde(rate_bits)).collect()
}
pub fn lde(&self, rate_bits: usize) -> Self {
let coeffs = ifft(self).lde(rate_bits);
fft_with_options(&coeffs, Some(rate_bits), None)
}
pub fn degree(&self) -> usize {
self.degree_plus_one()
.checked_sub(1)
.expect("deg(0) is undefined")
}
pub fn degree_plus_one(&self) -> usize {
self.ifft().degree_plus_one()
}
}
impl<F: Field> From<Vec<F>> for PolynomialValues<F> {
fn from(values: Vec<F>) -> Self {
Self::new(values)
}
}
/// A polynomial in coefficient form.
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(bound = "")]
pub struct PolynomialCoeffs<F: Field> {
pub(crate) coeffs: Vec<F>,
}
impl<F: Field> PolynomialCoeffs<F> {
pub fn new(coeffs: Vec<F>) -> Self {
PolynomialCoeffs { coeffs }
}
pub(crate) fn empty() -> Self {
Self::new(Vec::new())
}
pub(crate) fn zero(len: usize) -> Self {
Self::new(vec![F::ZERO; len])
}
pub(crate) fn is_zero(&self) -> bool {
self.coeffs.iter().all(|x| x.is_zero())
}
/// The number of coefficients. This does not filter out any zero coefficients, so it is not
/// necessarily related to the degree.
pub fn len(&self) -> usize {
self.coeffs.len()
}
pub fn log_len(&self) -> usize {
log2_strict(self.len())
}
pub(crate) fn chunks(&self, chunk_size: usize) -> Vec<Self> {
self.coeffs
.chunks(chunk_size)
.map(|chunk| PolynomialCoeffs::new(chunk.to_vec()))
.collect()
}
pub fn eval(&self, x: F) -> F {
self.coeffs
.iter()
.rev()
.fold(F::ZERO, |acc, &c| acc * x + c)
}
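    // Horner's rule: folding from the highest coefficient evaluates
    // (...(c_{n-1} * x + c_{n-2}) * x + ...) * x + c_0 in n multiplications.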
/// Evaluate the polynomial at a point given its powers. The first power is the point itself, not 1.
pub fn eval_with_powers(&self, powers: &[F]) -> F {
debug_assert_eq!(self.coeffs.len(), powers.len() + 1);
let acc = self.coeffs[0];
self.coeffs[1..]
.iter()
.zip(powers)
.fold(acc, |acc, (&x, &c)| acc + c * x)
}
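    // Illustrative usage (hypothetical values): for coeffs [c0, c1, c2], pass
    // powers = [x, x * x]; the fold then computes c0 + c1 * x + c2 * x^2.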
pub fn eval_base<const D: usize>(&self, x: F::BaseField) -> F
where
F: FieldExtension<D>,
{
self.coeffs
.iter()
.rev()
.fold(F::ZERO, |acc, &c| acc.scalar_mul(x) + c)
}
/// Evaluate the polynomial at a point given its powers. The first power is the point itself, not 1.
pub fn eval_base_with_powers<const D: usize>(&self, powers: &[F::BaseField]) -> F
where
F: FieldExtension<D>,
{
debug_assert_eq!(self.coeffs.len(), powers.len() + 1);
let acc = self.coeffs[0];
self.coeffs[1..]
.iter()
.zip(powers)
.fold(acc, |acc, (&x, &c)| acc + x.scalar_mul(c))
}
pub fn lde_multiple(polys: Vec<&Self>, rate_bits: usize) -> Vec<Self> {
polys.into_iter().map(|p| p.lde(rate_bits)).collect()
}
pub fn lde(&self, rate_bits: usize) -> Self {
self.padded(self.len() << rate_bits)
}
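    // Zero-padding the coefficient vector by a factor of 2^rate_bits leaves the
    // polynomial unchanged; a subsequent FFT simply evaluates it on a subgroup
    // 2^rate_bits times larger, which is the low-degree extension.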
pub(crate) fn pad(&mut self, new_len: usize) -> Result<()> {
ensure!(
new_len >= self.len(),
"Trying to pad a polynomial of length {} to a length of {}.",
self.len(),
new_len
);
self.coeffs.resize(new_len, F::ZERO);
Ok(())
}
pub(crate) fn padded(&self, new_len: usize) -> Self {
let mut poly = self.clone();
poly.pad(new_len).unwrap();
poly
}
/// Removes leading zero coefficients.
pub fn trim(&mut self) {
self.coeffs.truncate(self.degree_plus_one());
}
/// Removes leading zero coefficients.
pub fn trimmed(&self) -> Self {
let coeffs = self.coeffs[..self.degree_plus_one()].to_vec();
Self { coeffs }
}
/// Degree of the polynomial + 1, or 0 for a polynomial with no non-zero coefficients.
pub(crate) fn degree_plus_one(&self) -> usize {
(0usize..self.len())
.rev()
.find(|&i| self.coeffs[i].is_nonzero())
.map_or(0, |i| i + 1)
}
/// Leading coefficient.
pub fn lead(&self) -> F {
self.coeffs
.iter()
.rev()
.find(|x| x.is_nonzero())
.map_or(F::ZERO, |x| *x)
}
    /// Reverse the order of the coefficients, ignoring any leading zero coefficients.
pub(crate) fn rev(&self) -> Self {
Self::new(self.trimmed().coeffs.into_iter().rev().collect())
}
pub fn fft(&self) -> PolynomialValues<F> {
fft(self)
}
pub fn fft_with_options(
&self,
zero_factor: Option<usize>,
root_table: Option<&FftRootTable<F>>,
) -> PolynomialValues<F> {
fft_with_options(self, zero_factor, root_table)
}
/// Returns the evaluation of the polynomial on the coset `shift*H`.
pub fn coset_fft(&self, shift: F) -> PolynomialValues<F> {
self.coset_fft_with_options(shift, None, None)
}
/// Returns the evaluation of the polynomial on the coset `shift*H`.
pub fn coset_fft_with_options(
&self,
shift: F,
zero_factor: Option<usize>,
root_table: Option<&FftRootTable<F>>,
) -> PolynomialValues<F> {
let modified_poly: Self = shift
.powers()
.zip(&self.coeffs)
.map(|(r, &c)| r * c)
.collect::<Vec<_>>()
.into();
modified_poly.fft_with_options(zero_factor, root_table)
}
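    // The scaling works because sum_i (shift^i * c_i) x^i = p(shift * x), so
    // evaluating the modified polynomial on H gives p on the coset shift*H.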
pub fn to_extension<const D: usize>(&self) -> PolynomialCoeffs<F::Extension>
where
F: Extendable<D>,
{
PolynomialCoeffs::new(self.coeffs.iter().map(|&c| c.into()).collect())
}
pub fn mul_extension<const D: usize>(&self, rhs: F::Extension) -> PolynomialCoeffs<F::Extension>
where
F: Extendable<D>,
{
PolynomialCoeffs::new(self.coeffs.iter().map(|&c| rhs.scalar_mul(c)).collect())
}
}
impl<F: Field> PartialEq for PolynomialCoeffs<F> {
fn eq(&self, other: &Self) -> bool {
let max_terms = self.coeffs.len().max(other.coeffs.len());
for i in 0..max_terms {
let self_i = self.coeffs.get(i).cloned().unwrap_or(F::ZERO);
let other_i = other.coeffs.get(i).cloned().unwrap_or(F::ZERO);
if self_i != other_i {
return false;
}
}
true
}
}
impl<F: Field> Eq for PolynomialCoeffs<F> {}
impl<F: Field> From<Vec<F>> for PolynomialCoeffs<F> {
fn from(coeffs: Vec<F>) -> Self {
Self::new(coeffs)
}
}
impl<F: Field> Add for &PolynomialCoeffs<F> {
type Output = PolynomialCoeffs<F>;
fn add(self, rhs: Self) -> Self::Output {
let len = max(self.len(), rhs.len());
let a = self.padded(len).coeffs;
let b = rhs.padded(len).coeffs;
let coeffs = a.into_iter().zip(b).map(|(x, y)| x + y).collect();
PolynomialCoeffs::new(coeffs)
}
}
impl<F: Field> Sum for PolynomialCoeffs<F> {
fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
iter.fold(Self::empty(), |acc, p| &acc + &p)
}
}
impl<F: Field> Sub for &PolynomialCoeffs<F> {
type Output = PolynomialCoeffs<F>;
fn sub(self, rhs: Self) -> Self::Output {
let len = max(self.len(), rhs.len());
let mut coeffs = self.padded(len).coeffs;
for (i, &c) in rhs.coeffs.iter().enumerate() {
coeffs[i] -= c;
}
PolynomialCoeffs::new(coeffs)
}
}
impl<F: Field> AddAssign for PolynomialCoeffs<F> {
fn add_assign(&mut self, rhs: Self) {
let len = max(self.len(), rhs.len());
self.coeffs.resize(len, F::ZERO);
for (l, r) in self.coeffs.iter_mut().zip(rhs.coeffs) {
*l += r;
}
}
}
impl<F: Field> AddAssign<&Self> for PolynomialCoeffs<F> {
fn add_assign(&mut self, rhs: &Self) {
let len = max(self.len(), rhs.len());
self.coeffs.resize(len, F::ZERO);
for (l, &r) in self.coeffs.iter_mut().zip(&rhs.coeffs) {
*l += r;
}
}
}
impl<F: Field> SubAssign for PolynomialCoeffs<F> {
fn sub_assign(&mut self, rhs: Self) {
let len = max(self.len(), rhs.len());
self.coeffs.resize(len, F::ZERO);
for (l, r) in self.coeffs.iter_mut().zip(rhs.coeffs) {
*l -= r;
}
}
}
impl<F: Field> SubAssign<&Self> for PolynomialCoeffs<F> {
fn sub_assign(&mut self, rhs: &Self) {
let len = max(self.len(), rhs.len());
self.coeffs.resize(len, F::ZERO);
for (l, &r) in self.coeffs.iter_mut().zip(&rhs.coeffs) {
*l -= r;
}
}
}
impl<F: Field> Mul<F> for &PolynomialCoeffs<F> {
type Output = PolynomialCoeffs<F>;
fn mul(self, rhs: F) -> Self::Output {
let coeffs = self.coeffs.iter().map(|&x| rhs * x).collect();
PolynomialCoeffs::new(coeffs)
}
}
impl<F: Field> MulAssign<F> for PolynomialCoeffs<F> {
fn mul_assign(&mut self, rhs: F) {
self.coeffs.iter_mut().for_each(|x| *x *= rhs);
}
}
impl<F: Field> Mul for &PolynomialCoeffs<F> {
type Output = PolynomialCoeffs<F>;
#[allow(clippy::suspicious_arithmetic_impl)]
fn mul(self, rhs: Self) -> Self::Output {
let new_len = (self.len() + rhs.len()).next_power_of_two();
let a = self.padded(new_len);
let b = rhs.padded(new_len);
let a_evals = a.fft();
let b_evals = b.fft();
let mul_evals: Vec<F> = a_evals
.values
.into_iter()
.zip(b_evals.values)
.map(|(pa, pb)| pa * pb)
.collect();
ifft(&mul_evals.into())
}
}
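// Multiplying evaluations pointwise corresponds to multiplying the polynomials.
// Padding both inputs to `self.len() + rhs.len()` points suffices, since
// deg(a * b) = deg(a) + deg(b) < a.len() + b.len().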
#[cfg(test)]
mod tests {
use std::time::Instant;
use rand::{thread_rng, Rng};
use super::*;
use crate::field::goldilocks_field::GoldilocksField;
#[test]
fn test_trimmed() {
type F = GoldilocksField;
assert_eq!(
PolynomialCoeffs::<F> { coeffs: vec![] }.trimmed(),
PolynomialCoeffs::<F> { coeffs: vec![] }
);
assert_eq!(
PolynomialCoeffs::<F> {
coeffs: vec![F::ZERO]
}
.trimmed(),
PolynomialCoeffs::<F> { coeffs: vec![] }
);
assert_eq!(
PolynomialCoeffs::<F> {
coeffs: vec![F::ONE, F::TWO, F::ZERO, F::ZERO]
}
.trimmed(),
PolynomialCoeffs::<F> {
coeffs: vec![F::ONE, F::TWO]
}
);
}
#[test]
fn test_coset_fft() {
type F = GoldilocksField;
let k = 8;
let n = 1 << k;
let poly = PolynomialCoeffs::new(F::rand_vec(n));
let shift = F::rand();
let coset_evals = poly.coset_fft(shift).values;
let generator = F::primitive_root_of_unity(k);
let naive_coset_evals = F::cyclic_subgroup_coset_known_order(generator, shift, n)
.into_iter()
.map(|x| poly.eval(x))
.collect::<Vec<_>>();
assert_eq!(coset_evals, naive_coset_evals);
let ifft_coeffs = PolynomialValues::new(coset_evals).coset_ifft(shift);
assert_eq!(poly, ifft_coeffs);
}
#[test]
fn test_coset_ifft() {
type F = GoldilocksField;
let k = 8;
let n = 1 << k;
let evals = PolynomialValues::new(F::rand_vec(n));
let shift = F::rand();
let coeffs = evals.coset_ifft(shift);
let generator = F::primitive_root_of_unity(k);
let naive_coset_evals = F::cyclic_subgroup_coset_known_order(generator, shift, n)
.into_iter()
.map(|x| coeffs.eval(x))
.collect::<Vec<_>>();
assert_eq!(evals, naive_coset_evals.into());
let fft_evals = coeffs.coset_fft(shift);
assert_eq!(evals, fft_evals);
}
#[test]
fn test_polynomial_multiplication() {
type F = GoldilocksField;
let mut rng = thread_rng();
let (a_deg, b_deg) = (rng.gen_range(1..10_000), rng.gen_range(1..10_000));
let a = PolynomialCoeffs::new(F::rand_vec(a_deg));
let b = PolynomialCoeffs::new(F::rand_vec(b_deg));
let m1 = &a * &b;
let m2 = &a * &b;
for _ in 0..1000 {
let x = F::rand();
assert_eq!(m1.eval(x), a.eval(x) * b.eval(x));
assert_eq!(m2.eval(x), a.eval(x) * b.eval(x));
}
}
#[test]
fn test_inv_mod_xn() {
type F = GoldilocksField;
let mut rng = thread_rng();
let a_deg = rng.gen_range(1..1_000);
let n = rng.gen_range(1..1_000);
let a = PolynomialCoeffs::new(F::rand_vec(a_deg));
let b = a.inv_mod_xn(n);
let mut m = &a * &b;
m.coeffs.drain(n..);
m.trim();
assert_eq!(
m,
PolynomialCoeffs::new(vec![F::ONE]),
"a: {:#?}, b:{:#?}, n:{:#?}, m:{:#?}",
a,
b,
n,
m
);
}
#[test]
fn test_polynomial_long_division() {
type F = GoldilocksField;
let mut rng = thread_rng();
let (a_deg, b_deg) = (rng.gen_range(1..10_000), rng.gen_range(1..10_000));
let a = PolynomialCoeffs::new(F::rand_vec(a_deg));
let b = PolynomialCoeffs::new(F::rand_vec(b_deg));
let (q, r) = a.div_rem_long_division(&b);
for _ in 0..1000 {
let x = F::rand();
assert_eq!(a.eval(x), b.eval(x) * q.eval(x) + r.eval(x));
}
}
#[test]
fn test_polynomial_division() {
type F = GoldilocksField;
let mut rng = thread_rng();
let (a_deg, b_deg) = (rng.gen_range(1..10_000), rng.gen_range(1..10_000));
let a = PolynomialCoeffs::new(F::rand_vec(a_deg));
let b = PolynomialCoeffs::new(F::rand_vec(b_deg));
let (q, r) = a.div_rem(&b);
for _ in 0..1000 {
let x = F::rand();
assert_eq!(a.eval(x), b.eval(x) * q.eval(x) + r.eval(x));
}
}
#[test]
fn test_polynomial_division_by_constant() {
type F = GoldilocksField;
let mut rng = thread_rng();
let a_deg = rng.gen_range(1..10_000);
let a = PolynomialCoeffs::new(F::rand_vec(a_deg));
let b = PolynomialCoeffs::from(vec![F::rand()]);
let (q, r) = a.div_rem(&b);
for _ in 0..1000 {
let x = F::rand();
assert_eq!(a.eval(x), b.eval(x) * q.eval(x) + r.eval(x));
}
}
    // Test to see which polynomial division method is faster for divisions of the form
    // `(X^n - 1)/(X - a)`.
#[test]
fn test_division_linear() {
type F = GoldilocksField;
let mut rng = thread_rng();
let l = 14;
let n = 1 << l;
let g = F::primitive_root_of_unity(l);
let xn_minus_one = {
let mut xn_min_one_vec = vec![F::ZERO; n + 1];
xn_min_one_vec[n] = F::ONE;
xn_min_one_vec[0] = F::NEG_ONE;
PolynomialCoeffs::new(xn_min_one_vec)
};
let a = g.exp_u64(rng.gen_range(0..(n as u64)));
let denom = PolynomialCoeffs::new(vec![-a, F::ONE]);
let now = Instant::now();
xn_minus_one.div_rem(&denom);
println!("Division time: {:?}", now.elapsed());
let now = Instant::now();
xn_minus_one.div_rem_long_division(&denom);
println!("Division time: {:?}", now.elapsed());
}
#[test]
fn eq() {
type F = GoldilocksField;
assert_eq!(
PolynomialCoeffs::<F>::new(vec![]),
PolynomialCoeffs::new(vec![])
);
assert_eq!(
PolynomialCoeffs::<F>::new(vec![F::ZERO]),
PolynomialCoeffs::new(vec![F::ZERO])
);
assert_eq!(
PolynomialCoeffs::<F>::new(vec![]),
PolynomialCoeffs::new(vec![F::ZERO])
);
assert_eq!(
PolynomialCoeffs::<F>::new(vec![F::ZERO]),
PolynomialCoeffs::new(vec![])
);
assert_eq!(
PolynomialCoeffs::<F>::new(vec![F::ZERO]),
PolynomialCoeffs::new(vec![F::ZERO, F::ZERO])
);
assert_eq!(
PolynomialCoeffs::<F>::new(vec![F::ONE]),
PolynomialCoeffs::new(vec![F::ONE, F::ZERO])
);
assert_ne!(
PolynomialCoeffs::<F>::new(vec![]),
PolynomialCoeffs::new(vec![F::ONE])
);
assert_ne!(
PolynomialCoeffs::<F>::new(vec![F::ZERO]),
PolynomialCoeffs::new(vec![F::ZERO, F::ONE])
);
assert_ne!(
PolynomialCoeffs::<F>::new(vec![F::ZERO]),
PolynomialCoeffs::new(vec![F::ONE, F::ZERO])
);
}
}
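
A minimal usage sketch of the relocated API, assuming only the crate paths introduced in this diff (`plonky2::polynomial::PolynomialCoeffs` re-exported from the new `polynomial` module, plus the Goldilocks field):

use plonky2::field::field_types::Field;
use plonky2::field::goldilocks_field::GoldilocksField;
use plonky2::polynomial::PolynomialCoeffs;

fn main() {
    type F = GoldilocksField;
    // p(x) = 1 + 2x + 3x^2, padded to a power-of-two length for the FFT.
    let p = PolynomialCoeffs::new(vec![F::ONE, F::TWO, F::from_canonical_u64(3), F::ZERO]);
    // Evaluate on the size-4 subgroup and interpolate back: a lossless round trip.
    let evals = p.fft();
    assert_eq!(evals.ifft(), p);
}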


@@ -1,635 +0,0 @@
use std::cmp::max;
use std::iter::Sum;
use std::ops::{Add, AddAssign, Mul, MulAssign, Sub, SubAssign};
use anyhow::{ensure, Result};
use serde::{Deserialize, Serialize};
use crate::field::extension_field::{Extendable, FieldExtension};
use crate::field::fft::{fft, fft_with_options, ifft, FftRootTable};
use crate::field::field_types::Field;
use crate::util::log2_strict;
/// A polynomial in point-value form.
///
/// The points are implicitly `g^i`, where `g` generates the subgroup whose size equals the number
/// of points.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct PolynomialValues<F: Field> {
pub values: Vec<F>,
}
impl<F: Field> PolynomialValues<F> {
pub fn new(values: Vec<F>) -> Self {
PolynomialValues { values }
}
pub(crate) fn zero(len: usize) -> Self {
Self::new(vec![F::ZERO; len])
}
/// The number of values stored.
pub(crate) fn len(&self) -> usize {
self.values.len()
}
pub fn ifft(&self) -> PolynomialCoeffs<F> {
ifft(self)
}
/// Returns the polynomial whose evaluation on the coset `shift*H` is `self`.
pub fn coset_ifft(&self, shift: F) -> PolynomialCoeffs<F> {
let mut shifted_coeffs = self.ifft();
shifted_coeffs
.coeffs
.iter_mut()
.zip(shift.inverse().powers())
.for_each(|(c, r)| {
*c *= r;
});
shifted_coeffs
}
pub fn lde_multiple(polys: Vec<Self>, rate_bits: usize) -> Vec<Self> {
polys.into_iter().map(|p| p.lde(rate_bits)).collect()
}
pub fn lde(&self, rate_bits: usize) -> Self {
let coeffs = ifft(self).lde(rate_bits);
fft_with_options(&coeffs, Some(rate_bits), None)
}
pub fn degree(&self) -> usize {
self.degree_plus_one()
.checked_sub(1)
.expect("deg(0) is undefined")
}
pub fn degree_plus_one(&self) -> usize {
self.ifft().degree_plus_one()
}
}
impl<F: Field> From<Vec<F>> for PolynomialValues<F> {
fn from(values: Vec<F>) -> Self {
Self::new(values)
}
}
/// A polynomial in coefficient form.
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(bound = "")]
pub struct PolynomialCoeffs<F: Field> {
pub(crate) coeffs: Vec<F>,
}
impl<F: Field> PolynomialCoeffs<F> {
pub fn new(coeffs: Vec<F>) -> Self {
PolynomialCoeffs { coeffs }
}
/// Create a new polynomial with its coefficient list padded to the next power of two.
pub(crate) fn new_padded(mut coeffs: Vec<F>) -> Self {
while !coeffs.len().is_power_of_two() {
coeffs.push(F::ZERO);
}
PolynomialCoeffs { coeffs }
}
pub(crate) fn empty() -> Self {
Self::new(Vec::new())
}
pub(crate) fn zero(len: usize) -> Self {
Self::new(vec![F::ZERO; len])
}
pub(crate) fn one() -> Self {
Self::new(vec![F::ONE])
}
pub(crate) fn is_zero(&self) -> bool {
self.coeffs.iter().all(|x| x.is_zero())
}
/// The number of coefficients. This does not filter out any zero coefficients, so it is not
/// necessarily related to the degree.
pub fn len(&self) -> usize {
self.coeffs.len()
}
pub fn log_len(&self) -> usize {
log2_strict(self.len())
}
pub(crate) fn chunks(&self, chunk_size: usize) -> Vec<Self> {
self.coeffs
.chunks(chunk_size)
.map(|chunk| PolynomialCoeffs::new(chunk.to_vec()))
.collect()
}
pub fn eval(&self, x: F) -> F {
self.coeffs
.iter()
.rev()
.fold(F::ZERO, |acc, &c| acc * x + c)
}
pub fn eval_base<const D: usize>(&self, x: F::BaseField) -> F
where
F: FieldExtension<D>,
{
self.coeffs
.iter()
.rev()
.fold(F::ZERO, |acc, &c| acc.scalar_mul(x) + c)
}
pub fn lde_multiple(polys: Vec<&Self>, rate_bits: usize) -> Vec<Self> {
polys.into_iter().map(|p| p.lde(rate_bits)).collect()
}
pub fn lde(&self, rate_bits: usize) -> Self {
self.padded(self.len() << rate_bits)
}
pub(crate) fn pad(&mut self, new_len: usize) -> Result<()> {
ensure!(
new_len >= self.len(),
"Trying to pad a polynomial of length {} to a length of {}.",
self.len(),
new_len
);
self.coeffs.resize(new_len, F::ZERO);
Ok(())
}
pub(crate) fn padded(&self, new_len: usize) -> Self {
let mut poly = self.clone();
poly.pad(new_len).unwrap();
poly
}
/// Removes leading zero coefficients.
pub fn trim(&mut self) {
self.coeffs.truncate(self.degree_plus_one());
}
/// Removes leading zero coefficients.
pub fn trimmed(&self) -> Self {
let coeffs = self.coeffs[..self.degree_plus_one()].to_vec();
Self { coeffs }
}
/// Degree of the polynomial + 1, or 0 for a polynomial with no non-zero coefficients.
pub(crate) fn degree_plus_one(&self) -> usize {
(0usize..self.len())
.rev()
.find(|&i| self.coeffs[i].is_nonzero())
.map_or(0, |i| i + 1)
}
/// Leading coefficient.
pub fn lead(&self) -> F {
self.coeffs
.iter()
.rev()
.find(|x| x.is_nonzero())
.map_or(F::ZERO, |x| *x)
}
/// Reverse the order of the coefficients, not taking into account the leading zero coefficients.
pub(crate) fn rev(&self) -> Self {
Self::new(self.trimmed().coeffs.into_iter().rev().collect())
}
pub fn fft(&self) -> PolynomialValues<F> {
fft(self)
}
pub fn fft_with_options(
&self,
zero_factor: Option<usize>,
root_table: Option<&FftRootTable<F>>,
) -> PolynomialValues<F> {
fft_with_options(self, zero_factor, root_table)
}
/// Returns the evaluation of the polynomial on the coset `shift*H`.
pub fn coset_fft(&self, shift: F) -> PolynomialValues<F> {
self.coset_fft_with_options(shift, None, None)
}
/// Returns the evaluation of the polynomial on the coset `shift*H`.
pub fn coset_fft_with_options(
&self,
shift: F,
zero_factor: Option<usize>,
root_table: Option<&FftRootTable<F>>,
) -> PolynomialValues<F> {
let modified_poly: Self = shift
.powers()
.zip(&self.coeffs)
.map(|(r, &c)| r * c)
.collect::<Vec<_>>()
.into();
modified_poly.fft_with_options(zero_factor, root_table)
}
pub fn to_extension<const D: usize>(&self) -> PolynomialCoeffs<F::Extension>
where
F: Extendable<D>,
{
PolynomialCoeffs::new(self.coeffs.iter().map(|&c| c.into()).collect())
}
pub fn mul_extension<const D: usize>(&self, rhs: F::Extension) -> PolynomialCoeffs<F::Extension>
where
F: Extendable<D>,
{
PolynomialCoeffs::new(self.coeffs.iter().map(|&c| rhs.scalar_mul(c)).collect())
}
}
impl<F: Field> PartialEq for PolynomialCoeffs<F> {
fn eq(&self, other: &Self) -> bool {
let max_terms = self.coeffs.len().max(other.coeffs.len());
for i in 0..max_terms {
let self_i = self.coeffs.get(i).cloned().unwrap_or(F::ZERO);
let other_i = other.coeffs.get(i).cloned().unwrap_or(F::ZERO);
if self_i != other_i {
return false;
}
}
true
}
}
impl<F: Field> Eq for PolynomialCoeffs<F> {}
impl<F: Field> From<Vec<F>> for PolynomialCoeffs<F> {
fn from(coeffs: Vec<F>) -> Self {
Self::new(coeffs)
}
}
impl<F: Field> Add for &PolynomialCoeffs<F> {
type Output = PolynomialCoeffs<F>;
fn add(self, rhs: Self) -> Self::Output {
let len = max(self.len(), rhs.len());
let a = self.padded(len).coeffs;
let b = rhs.padded(len).coeffs;
let coeffs = a.into_iter().zip(b).map(|(x, y)| x + y).collect();
PolynomialCoeffs::new(coeffs)
}
}
impl<F: Field> Sum for PolynomialCoeffs<F> {
fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
iter.fold(Self::empty(), |acc, p| &acc + &p)
}
}
impl<F: Field> Sub for &PolynomialCoeffs<F> {
type Output = PolynomialCoeffs<F>;
fn sub(self, rhs: Self) -> Self::Output {
let len = max(self.len(), rhs.len());
let mut coeffs = self.padded(len).coeffs;
for (i, &c) in rhs.coeffs.iter().enumerate() {
coeffs[i] -= c;
}
PolynomialCoeffs::new(coeffs)
}
}
impl<F: Field> AddAssign for PolynomialCoeffs<F> {
fn add_assign(&mut self, rhs: Self) {
let len = max(self.len(), rhs.len());
self.coeffs.resize(len, F::ZERO);
for (l, r) in self.coeffs.iter_mut().zip(rhs.coeffs) {
*l += r;
}
}
}
impl<F: Field> AddAssign<&Self> for PolynomialCoeffs<F> {
fn add_assign(&mut self, rhs: &Self) {
let len = max(self.len(), rhs.len());
self.coeffs.resize(len, F::ZERO);
for (l, &r) in self.coeffs.iter_mut().zip(&rhs.coeffs) {
*l += r;
}
}
}
impl<F: Field> SubAssign for PolynomialCoeffs<F> {
fn sub_assign(&mut self, rhs: Self) {
let len = max(self.len(), rhs.len());
self.coeffs.resize(len, F::ZERO);
for (l, r) in self.coeffs.iter_mut().zip(rhs.coeffs) {
*l -= r;
}
}
}
impl<F: Field> SubAssign<&Self> for PolynomialCoeffs<F> {
fn sub_assign(&mut self, rhs: &Self) {
let len = max(self.len(), rhs.len());
self.coeffs.resize(len, F::ZERO);
for (l, &r) in self.coeffs.iter_mut().zip(&rhs.coeffs) {
*l -= r;
}
}
}
impl<F: Field> Mul<F> for &PolynomialCoeffs<F> {
type Output = PolynomialCoeffs<F>;
fn mul(self, rhs: F) -> Self::Output {
let coeffs = self.coeffs.iter().map(|&x| rhs * x).collect();
PolynomialCoeffs::new(coeffs)
}
}
impl<F: Field> MulAssign<F> for PolynomialCoeffs<F> {
fn mul_assign(&mut self, rhs: F) {
self.coeffs.iter_mut().for_each(|x| *x *= rhs);
}
}
impl<F: Field> Mul for &PolynomialCoeffs<F> {
type Output = PolynomialCoeffs<F>;
#[allow(clippy::suspicious_arithmetic_impl)]
fn mul(self, rhs: Self) -> Self::Output {
let new_len = (self.len() + rhs.len()).next_power_of_two();
let a = self.padded(new_len);
let b = rhs.padded(new_len);
let a_evals = a.fft();
let b_evals = b.fft();
let mul_evals: Vec<F> = a_evals
.values
.into_iter()
.zip(b_evals.values)
.map(|(pa, pb)| pa * pb)
.collect();
ifft(&mul_evals.into())
}
}
#[cfg(test)]
mod tests {
use std::time::Instant;
use rand::{thread_rng, Rng};
use super::*;
use crate::field::goldilocks_field::GoldilocksField;
#[test]
fn test_trimmed() {
type F = GoldilocksField;
assert_eq!(
PolynomialCoeffs::<F> { coeffs: vec![] }.trimmed(),
PolynomialCoeffs::<F> { coeffs: vec![] }
);
assert_eq!(
PolynomialCoeffs::<F> {
coeffs: vec![F::ZERO]
}
.trimmed(),
PolynomialCoeffs::<F> { coeffs: vec![] }
);
assert_eq!(
PolynomialCoeffs::<F> {
coeffs: vec![F::ONE, F::TWO, F::ZERO, F::ZERO]
}
.trimmed(),
PolynomialCoeffs::<F> {
coeffs: vec![F::ONE, F::TWO]
}
);
}
#[test]
fn test_coset_fft() {
type F = GoldilocksField;
let k = 8;
let n = 1 << k;
let poly = PolynomialCoeffs::new(F::rand_vec(n));
let shift = F::rand();
let coset_evals = poly.coset_fft(shift).values;
let generator = F::primitive_root_of_unity(k);
let naive_coset_evals = F::cyclic_subgroup_coset_known_order(generator, shift, n)
.into_iter()
.map(|x| poly.eval(x))
.collect::<Vec<_>>();
assert_eq!(coset_evals, naive_coset_evals);
let ifft_coeffs = PolynomialValues::new(coset_evals).coset_ifft(shift);
assert_eq!(poly, ifft_coeffs.into());
}
#[test]
fn test_coset_ifft() {
type F = GoldilocksField;
let k = 8;
let n = 1 << k;
let evals = PolynomialValues::new(F::rand_vec(n));
let shift = F::rand();
let coeffs = evals.coset_ifft(shift);
let generator = F::primitive_root_of_unity(k);
let naive_coset_evals = F::cyclic_subgroup_coset_known_order(generator, shift, n)
.into_iter()
.map(|x| coeffs.eval(x))
.collect::<Vec<_>>();
assert_eq!(evals, naive_coset_evals.into());
let fft_evals = coeffs.coset_fft(shift);
assert_eq!(evals, fft_evals);
}
#[test]
fn test_polynomial_multiplication() {
type F = GoldilocksField;
let mut rng = thread_rng();
let (a_deg, b_deg) = (rng.gen_range(1..10_000), rng.gen_range(1..10_000));
let a = PolynomialCoeffs::new(F::rand_vec(a_deg));
let b = PolynomialCoeffs::new(F::rand_vec(b_deg));
let m1 = &a * &b;
let m2 = &a * &b;
for _ in 0..1000 {
let x = F::rand();
assert_eq!(m1.eval(x), a.eval(x) * b.eval(x));
assert_eq!(m2.eval(x), a.eval(x) * b.eval(x));
}
}
#[test]
fn test_inv_mod_xn() {
type F = GoldilocksField;
let mut rng = thread_rng();
let a_deg = rng.gen_range(1..1_000);
let n = rng.gen_range(1..1_000);
let a = PolynomialCoeffs::new(F::rand_vec(a_deg));
let b = a.inv_mod_xn(n);
let mut m = &a * &b;
m.coeffs.drain(n..);
m.trim();
assert_eq!(
m,
PolynomialCoeffs::new(vec![F::ONE]),
"a: {:#?}, b:{:#?}, n:{:#?}, m:{:#?}",
a,
b,
n,
m
);
}
#[test]
fn test_polynomial_long_division() {
type F = GoldilocksField;
let mut rng = thread_rng();
let (a_deg, b_deg) = (rng.gen_range(1..10_000), rng.gen_range(1..10_000));
let a = PolynomialCoeffs::new(F::rand_vec(a_deg));
let b = PolynomialCoeffs::new(F::rand_vec(b_deg));
let (q, r) = a.div_rem_long_division(&b);
for _ in 0..1000 {
let x = F::rand();
assert_eq!(a.eval(x), b.eval(x) * q.eval(x) + r.eval(x));
}
}
#[test]
fn test_polynomial_division() {
type F = GoldilocksField;
let mut rng = thread_rng();
let (a_deg, b_deg) = (rng.gen_range(1..10_000), rng.gen_range(1..10_000));
let a = PolynomialCoeffs::new(F::rand_vec(a_deg));
let b = PolynomialCoeffs::new(F::rand_vec(b_deg));
let (q, r) = a.div_rem(&b);
for _ in 0..1000 {
let x = F::rand();
assert_eq!(a.eval(x), b.eval(x) * q.eval(x) + r.eval(x));
}
}
#[test]
fn test_polynomial_division_by_constant() {
type F = GoldilocksField;
let mut rng = thread_rng();
let a_deg = rng.gen_range(1..10_000);
let a = PolynomialCoeffs::new(F::rand_vec(a_deg));
let b = PolynomialCoeffs::from(vec![F::rand()]);
let (q, r) = a.div_rem(&b);
for _ in 0..1000 {
let x = F::rand();
assert_eq!(a.eval(x), b.eval(x) * q.eval(x) + r.eval(x));
}
}
#[test]
fn test_division_by_z_h() {
type F = GoldilocksField;
let mut rng = thread_rng();
let a_deg = rng.gen_range(1..10_000);
let n = rng.gen_range(1..a_deg);
let mut a = PolynomialCoeffs::new(F::rand_vec(a_deg));
a.trim();
let z_h = {
let mut z_h_vec = vec![F::ZERO; n + 1];
z_h_vec[n] = F::ONE;
z_h_vec[0] = F::NEG_ONE;
PolynomialCoeffs::new(z_h_vec)
};
let m = &a * &z_h;
let now = Instant::now();
let mut a_test = m.divide_by_z_h(n);
a_test.trim();
println!("Division time: {:?}", now.elapsed());
assert_eq!(a, a_test);
}
#[test]
fn divide_zero_poly_by_z_h() {
let zero_poly = PolynomialCoeffs::<GoldilocksField>::empty();
zero_poly.divide_by_z_h(16);
}
// Test to see which polynomial division method is faster for divisions of the type
// `(X^n - 1)/(X - a)
#[test]
fn test_division_linear() {
type F = GoldilocksField;
let mut rng = thread_rng();
let l = 14;
let n = 1 << l;
let g = F::primitive_root_of_unity(l);
let xn_minus_one = {
let mut xn_min_one_vec = vec![F::ZERO; n + 1];
xn_min_one_vec[n] = F::ONE;
xn_min_one_vec[0] = F::NEG_ONE;
PolynomialCoeffs::new(xn_min_one_vec)
};
let a = g.exp_u64(rng.gen_range(0..(n as u64)));
let denom = PolynomialCoeffs::new(vec![-a, F::ONE]);
let now = Instant::now();
xn_minus_one.div_rem(&denom);
println!("Division time: {:?}", now.elapsed());
let now = Instant::now();
xn_minus_one.div_rem_long_division(&denom);
println!("Division time: {:?}", now.elapsed());
}
#[test]
fn eq() {
type F = GoldilocksField;
assert_eq!(
PolynomialCoeffs::<F>::new(vec![]),
PolynomialCoeffs::new(vec![])
);
assert_eq!(
PolynomialCoeffs::<F>::new(vec![F::ZERO]),
PolynomialCoeffs::new(vec![F::ZERO])
);
assert_eq!(
PolynomialCoeffs::<F>::new(vec![]),
PolynomialCoeffs::new(vec![F::ZERO])
);
assert_eq!(
PolynomialCoeffs::<F>::new(vec![F::ZERO]),
PolynomialCoeffs::new(vec![])
);
assert_eq!(
PolynomialCoeffs::<F>::new(vec![F::ZERO]),
PolynomialCoeffs::new(vec![F::ZERO, F::ZERO])
);
assert_eq!(
PolynomialCoeffs::<F>::new(vec![F::ONE]),
PolynomialCoeffs::new(vec![F::ONE, F::ZERO])
);
assert_ne!(
PolynomialCoeffs::<F>::new(vec![]),
PolynomialCoeffs::new(vec![F::ONE])
);
assert_ne!(
PolynomialCoeffs::<F>::new(vec![F::ZERO]),
PolynomialCoeffs::new(vec![F::ZERO, F::ONE])
);
assert_ne!(
PolynomialCoeffs::<F>::new(vec![F::ZERO]),
PolynomialCoeffs::new(vec![F::ONE, F::ZERO])
);
}
}
