Merge branch 'main' into generic_configuration

# Conflicts:
#	src/field/extension_field/mod.rs
#	src/fri/recursive_verifier.rs
#	src/gadgets/arithmetic.rs
#	src/gadgets/arithmetic_extension.rs
#	src/gadgets/hash.rs
#	src/gadgets/interpolation.rs
#	src/gadgets/random_access.rs
#	src/gadgets/sorting.rs
#	src/gates/arithmetic_u32.rs
#	src/gates/gate_tree.rs
#	src/gates/interpolation.rs
#	src/gates/poseidon.rs
#	src/gates/poseidon_mds.rs
#	src/gates/random_access.rs
#	src/hash/hashing.rs
#	src/hash/merkle_proofs.rs
#	src/hash/poseidon.rs
#	src/iop/challenger.rs
#	src/iop/generator.rs
#	src/iop/witness.rs
#	src/plonk/circuit_data.rs
#	src/plonk/proof.rs
#	src/plonk/prover.rs
#	src/plonk/recursive_verifier.rs
#	src/util/partial_products.rs
#	src/util/reducing.rs
This commit is contained in:
wborgeaud 2021-12-16 14:54:38 +01:00
commit bdbc8b6931
105 changed files with 8303 additions and 2545 deletions

View File

@ -28,9 +28,11 @@ jobs:
with:
command: test
args: --all
env:
RUSTFLAGS: -Copt-level=3 -Cdebug-assertions -Coverflow-checks=y
lints:
name: Formatting
name: Formatting and Clippy
runs-on: ubuntu-latest
if: "! contains(toJSON(github.event.commits.*.message), '[skip-ci]')"
steps:
@ -43,10 +45,17 @@ jobs:
profile: minimal
toolchain: nightly
override: true
components: rustfmt
components: rustfmt, clippy
- name: Run cargo fmt
uses: actions-rs/cargo@v1
with:
command: fmt
args: --all -- --check
- name: Run cargo clippy
uses: actions-rs/cargo@v1
with:
command: clippy
args: --all-features --all-targets -- -D warnings -A incomplete-features

View File

@ -8,7 +8,7 @@ license = "MIT OR Apache-2.0"
repository = "https://github.com/mir-protocol/plonky2"
keywords = ["cryptography", "SNARK", "FRI"]
categories = ["cryptography"]
edition = "2018"
edition = "2021"
default-run = "bench_recursion"
[dependencies]
@ -28,6 +28,9 @@ serde_cbor = "0.11.1"
keccak-hash = "0.8.0"
static_assertions = "1.1.0"
[target.'cfg(not(target_env = "msvc"))'.dependencies]
jemallocator = "0.3.2"
[dev-dependencies]
criterion = "0.3.5"
tynm = "0.1.6"

View File

@ -1,7 +1,7 @@
use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion};
use plonky2::field::field_types::Field;
use plonky2::field::goldilocks_field::GoldilocksField;
use plonky2::polynomial::polynomial::PolynomialCoeffs;
use plonky2::polynomial::PolynomialCoeffs;
use tynm::type_name;
pub(crate) fn bench_ffts<F: Field>(c: &mut Criterion) {

View File

@ -1,5 +1,3 @@
#![feature(destructuring_assignment)]
use criterion::{criterion_group, criterion_main, BatchSize, Criterion};
use plonky2::field::extension_field::quartic::QuarticExtension;
use plonky2::field::field_types::Field;
@ -112,6 +110,66 @@ pub(crate) fn bench_field<F: Field>(c: &mut Criterion) {
c.bench_function(&format!("try_inverse<{}>", type_name::<F>()), |b| {
b.iter_batched(|| F::rand(), |x| x.try_inverse(), BatchSize::SmallInput)
});
c.bench_function(
&format!("batch_multiplicative_inverse-tiny<{}>", type_name::<F>()),
|b| {
b.iter_batched(
|| (0..2).into_iter().map(|_| F::rand()).collect::<Vec<_>>(),
|x| F::batch_multiplicative_inverse(&x),
BatchSize::SmallInput,
)
},
);
c.bench_function(
&format!("batch_multiplicative_inverse-small<{}>", type_name::<F>()),
|b| {
b.iter_batched(
|| (0..4).into_iter().map(|_| F::rand()).collect::<Vec<_>>(),
|x| F::batch_multiplicative_inverse(&x),
BatchSize::SmallInput,
)
},
);
c.bench_function(
&format!("batch_multiplicative_inverse-medium<{}>", type_name::<F>()),
|b| {
b.iter_batched(
|| (0..16).into_iter().map(|_| F::rand()).collect::<Vec<_>>(),
|x| F::batch_multiplicative_inverse(&x),
BatchSize::SmallInput,
)
},
);
c.bench_function(
&format!("batch_multiplicative_inverse-large<{}>", type_name::<F>()),
|b| {
b.iter_batched(
|| (0..256).into_iter().map(|_| F::rand()).collect::<Vec<_>>(),
|x| F::batch_multiplicative_inverse(&x),
BatchSize::LargeInput,
)
},
);
c.bench_function(
&format!("batch_multiplicative_inverse-huge<{}>", type_name::<F>()),
|b| {
b.iter_batched(
|| {
(0..65536)
.into_iter()
.map(|_| F::rand())
.collect::<Vec<_>>()
},
|x| F::batch_multiplicative_inverse(&x),
BatchSize::LargeInput,
)
},
);
}
fn criterion_benchmark(c: &mut Criterion) {

View File

@ -1,4 +1,3 @@
#![feature(destructuring_assignment)]
#![feature(generic_const_exprs)]
use criterion::{criterion_group, criterion_main, BatchSize, Criterion};
@ -19,7 +18,7 @@ pub(crate) fn bench_gmimc<F: GMiMC<WIDTH>, const WIDTH: usize>(c: &mut Criterion
pub(crate) fn bench_poseidon<F: Poseidon<WIDTH>, const WIDTH: usize>(c: &mut Criterion)
where
[(); WIDTH - 1]: ,
[(); WIDTH - 1]:,
{
c.bench_function(&format!("poseidon<{}, {}>", type_name::<F>(), WIDTH), |b| {
b.iter_batched(

View File

@ -2,7 +2,7 @@ use std::time::Instant;
use plonky2::field::field_types::Field;
use plonky2::field::goldilocks_field::GoldilocksField;
use plonky2::polynomial::polynomial::PolynomialValues;
use plonky2::polynomial::PolynomialValues;
use rayon::prelude::*;
type F = GoldilocksField;

View File

@ -24,6 +24,7 @@ fn bench_prove<C: GenericConfig<D>, const D: usize>() -> Result<()> {
num_wires: 126,
num_routed_wires: 33,
constant_gate_size: 6,
use_base_arithmetic_gate: false,
security_bits: 128,
rate_bits: 3,
num_challenges: 3,

View File

@ -1,5 +1,7 @@
//! Generates random constants using ChaCha20, seeded with zero.
#![allow(clippy::needless_range_loop)]
use plonky2::field::field_types::PrimeField;
use plonky2::field::goldilocks_field::GoldilocksField;
use rand::{Rng, SeedableRng};

156
src/curve/curve_adds.rs Normal file
View File

@ -0,0 +1,156 @@
use std::ops::Add;
use crate::curve::curve_types::{AffinePoint, Curve, ProjectivePoint};
use crate::field::field_types::Field;
impl<C: Curve> Add<ProjectivePoint<C>> for ProjectivePoint<C> {
    type Output = ProjectivePoint<C>;

    /// Adds two projective points.
    ///
    /// Identity operands, doubling and the sum of a point with its negation are handled before
    /// falling through to the general addition formula.
    fn add(self, rhs: ProjectivePoint<C>) -> Self::Output {
        let (x1, y1, z1) = (self.x, self.y, self.z);
        let (x2, y2, z2) = (rhs.x, rhs.y, rhs.z);

        // A zero z-coordinate marks the identity, so the other operand is the result.
        if z1 == C::BaseField::ZERO {
            return rhs;
        }
        if z2 == C::BaseField::ZERO {
            return self;
        }

        // Bring both points onto the common denominator z1 * z2.
        let x1z2 = x1 * z2;
        let y1z2 = y1 * z2;
        let x2z1 = x2 * z1;
        let y2z1 = y2 * z1;

        // Equal x-coordinates mean we are either doubling or adding inverses.
        let same_x = x1z2 == x2z1;
        if same_x && y1z2 == y2z1 {
            // TODO: inline to avoid redundant muls.
            return self.double();
        }
        if same_x && y1z2 == -y2z1 {
            return ProjectivePoint::ZERO;
        }

        // From https://www.hyperelliptic.org/EFD/g1p/data/shortw/projective/addition/add-1998-cmo-2
        let z1z2 = z1 * z2;
        let dy = y2z1 - y1z2;
        let dy_sq = dy.square();
        let dx = x2z1 - x1z2;
        let dx_sq = dx.square();
        let dx_cu = dx * dx_sq;
        let r = dx_sq * x1z2;
        let a = dy_sq * z1z2 - dx_cu - r.double();

        ProjectivePoint::nonzero(dx * a, dy * (r - a) - dx_cu * y1z2, dx_cu * z1z2)
    }
}
impl<C: Curve> Add<AffinePoint<C>> for ProjectivePoint<C> {
    type Output = ProjectivePoint<C>;

    /// Mixed addition: adds an affine point to a projective point.
    ///
    /// Identity operands, doubling and the sum of a point with its negation are handled before
    /// falling through to the general mixed-addition formula.
    fn add(self, rhs: AffinePoint<C>) -> Self::Output {
        let (x1, y1, z1) = (self.x, self.y, self.z);
        let (x2, y2) = (rhs.x, rhs.y);

        // If either operand is the identity, the other one is the result.
        if z1 == C::BaseField::ZERO {
            return rhs.to_projective();
        }
        if rhs.zero {
            return self;
        }

        // Scale the affine coordinates so they share the denominator z1 with `self`.
        let x2z1 = x2 * z1;
        let y2z1 = y2 * z1;

        // Equal x-coordinates mean we are either doubling or adding inverses.
        let same_x = x1 == x2z1;
        if same_x && y1 == y2z1 {
            // TODO: inline to avoid redundant muls.
            return self.double();
        }
        if same_x && y1 == -y2z1 {
            return ProjectivePoint::ZERO;
        }

        // From https://www.hyperelliptic.org/EFD/g1p/data/shortw/projective/addition/madd-1998-cmo
        let dy = y2z1 - y1;
        let dy_sq = dy.square();
        let dx = x2z1 - x1;
        let dx_sq = dx.square();
        let dx_cu = dx * dx_sq;
        let r = dx_sq * x1;
        let a = dy_sq * z1 - dx_cu - r.double();

        ProjectivePoint::nonzero(dx * a, dy * (r - a) - dx_cu * y1, dx_cu * z1)
    }
}
impl<C: Curve> Add<AffinePoint<C>> for AffinePoint<C> {
    type Output = ProjectivePoint<C>;

    /// Adds two affine points, producing a projective result (so no field inversion is needed).
    fn add(self, rhs: AffinePoint<C>) -> Self::Output {
        // If either operand is the identity, the other one is the result.
        if self.zero {
            return rhs.to_projective();
        }
        if rhs.zero {
            return self.to_projective();
        }

        let (x1, y1) = (self.x, self.y);
        let (x2, y2) = (rhs.x, rhs.y);

        // Equal x-coordinates mean we are either doubling or adding inverses.
        let same_x = x1 == x2;
        if same_x && y1 == y2 {
            return self.to_projective().double();
        }
        if same_x && y1 == -y2 {
            return ProjectivePoint::ZERO;
        }

        // From https://www.hyperelliptic.org/EFD/g1p/data/shortw/projective/addition/mmadd-1998-cmo
        let dy = y2 - y1;
        let dy_sq = dy.square();
        let dx = x2 - x1;
        let dx_sq = dx.square();
        let dx_cu = dx * dx_sq;
        let r = dx_sq * x1;
        let a = dy_sq - dx_cu - r.double();

        ProjectivePoint::nonzero(dx * a, dy * (r - a) - dx_cu * y1, dx_cu)
    }
}

263
src/curve/curve_msm.rs Normal file
View File

@ -0,0 +1,263 @@
use itertools::Itertools;
use rayon::prelude::*;
use crate::curve::curve_summation::affine_multisummation_best;
use crate::curve::curve_types::{AffinePoint, Curve, ProjectivePoint};
use crate::field::field_types::Field;
/// Number of digit values handed to one parallel task in `msm_execute_parallel` (it is the
/// chunk size passed to `par_chunks`).
///
/// In Yao's method, we compute an affine summation for each digit. In a parallel setting, it would
/// be easiest to assign individual summations to threads, but this would be sub-optimal because
/// multi-summations can be more efficient than repeating individual summations (see
/// `affine_multisummation_best`). Thus we divide digits into large chunks, and assign chunks of
/// digits to threads. Note that there is a delicate balance here, as large chunks can result in
/// uneven distributions of work among threads.
const DIGITS_PER_CHUNK: usize = 80;
/// Precomputed data for a multi-scalar multiplication over a fixed list of generators, as built
/// by `msm_precompute` and consumed by `msm_execute` / `msm_execute_parallel`.
#[derive(Clone, Debug)]
pub struct MsmPrecomputation<C: Curve> {
/// For each generator (in the order they were passed to `msm_precompute`), contains a vector
/// of powers, i.e. [(2^w)^i] g for i < digits, where digits = ceil(ScalarField::BITS / w).
// TODO: Use compressed coordinates here.
powers_per_generator: Vec<Vec<AffinePoint<C>>>,
/// The window size, in bits: each scalar digit is a base-2^w digit.
w: usize,
}
/// Builds the per-generator table of window powers used by the MSM routines.
///
/// `w` is the window size in bits. Generators are processed in parallel and the resulting table
/// keeps them in the order given.
pub fn msm_precompute<C: Curve>(
    generators: &[ProjectivePoint<C>],
    w: usize,
) -> MsmPrecomputation<C> {
    let powers_per_generator = generators
        .par_iter()
        .map(|&generator| precompute_single_generator(generator, w))
        .collect();

    MsmPrecomputation {
        powers_per_generator,
        w,
    }
}
/// Computes `[(2^w)^i] g` for every digit position `i` of a scalar and returns the results in
/// affine form.
fn precompute_single_generator<C: Curve>(g: ProjectivePoint<C>, w: usize) -> Vec<AffinePoint<C>> {
    let digits = (C::ScalarField::BITS + w - 1) / w;

    let mut powers: Vec<ProjectivePoint<C>> = Vec::with_capacity(digits);
    let mut current = g;
    powers.push(current);
    for _ in 1..digits {
        // Moving to the next digit position multiplies by 2^w, i.e. w doublings.
        for _ in 0..w {
            current = current.double();
        }
        powers.push(current);
    }

    ProjectivePoint::batch_to_affine(&powers)
}
/// One-shot multi-scalar multiplication: runs the precomputation for `generators` with window
/// size `w`, then evaluates it against `scalars` using the parallel executor.
pub fn msm_parallel<C: Curve>(
    scalars: &[C::ScalarField],
    generators: &[ProjectivePoint<C>],
    w: usize,
) -> ProjectivePoint<C> {
    msm_execute_parallel(&msm_precompute(generators, w), scalars)
}
/// Sequentially evaluates the multi-scalar multiplication `sum_i scalars[i] * generator_i` from
/// a precomputed table of window powers.
///
/// # Panics
///
/// Panics if the number of scalars differs from the number of generators in `precomputation`.
pub fn msm_execute<C: Curve>(
    precomputation: &MsmPrecomputation<C>,
    scalars: &[C::ScalarField],
) -> ProjectivePoint<C> {
    assert_eq!(precomputation.powers_per_generator.len(), scalars.len());
    let w = precomputation.w;
    let base: usize = 1 << w;

    // This is a variant of Yao's method, adapted to the multi-scalar setting. Because we use
    // extremely large windows, the repeated scans in Yao's method could be more expensive than the
    // actual group operations. To avoid this, we store a multimap from each possible digit to the
    // positions in which that digit occurs in the scalars. These positions have the form (i, j),
    // where i is the index of the generator and j is an index into the digits of the scalar
    // associated with that generator.
    //
    // There is one bucket per possible digit value, so the table has `base` entries (not one per
    // digit position).
    let mut digit_occurrences: Vec<Vec<(usize, usize)>> = vec![Vec::new(); base];
    for (i, scalar) in scalars.iter().enumerate() {
        for (j, digit) in to_digits::<C>(scalar, w).into_iter().enumerate() {
            digit_occurrences[digit].push((i, j));
        }
    }

    // Walking the digits from largest to smallest, `u` holds the sum of all powers whose digit is
    // at least the current one; adding `u` to `y` once per step weights each power by its digit.
    let mut y = ProjectivePoint::ZERO;
    let mut u = ProjectivePoint::ZERO;
    for digit in (1..base).rev() {
        for &(i, j) in &digit_occurrences[digit] {
            u = u + precomputation.powers_per_generator[i][j];
        }
        y = y + u;
    }
    y
}
/// Evaluates the multi-scalar multiplication `sum_i scalars[i] * generator_i` from a precomputed
/// table of window powers, computing the per-digit summations in parallel.
///
/// # Panics
///
/// Panics if the number of scalars differs from the number of generators in `precomputation`.
pub fn msm_execute_parallel<C: Curve>(
    precomputation: &MsmPrecomputation<C>,
    scalars: &[C::ScalarField],
) -> ProjectivePoint<C> {
    assert_eq!(precomputation.powers_per_generator.len(), scalars.len());
    let w = precomputation.w;
    let base: usize = 1 << w;

    // This is a variant of Yao's method, adapted to the multi-scalar setting. Because we use
    // extremely large windows, the repeated scans in Yao's method could be more expensive than the
    // actual group operations. To avoid this, we store a multimap from each possible digit to the
    // positions in which that digit occurs in the scalars. These positions have the form (i, j),
    // where i is the index of the generator and j is an index into the digits of the scalar
    // associated with that generator.
    //
    // There is one bucket per possible digit value, so the table has `base` entries (not one per
    // digit position).
    let mut digit_occurrences: Vec<Vec<(usize, usize)>> = vec![Vec::new(); base];
    for (i, scalar) in scalars.iter().enumerate() {
        for (j, digit) in to_digits::<C>(scalar, w).into_iter().enumerate() {
            digit_occurrences[digit].push((i, j));
        }
    }

    // For each digit, we add up the powers associated with all occurrences of that digit. Digits
    // are handed to threads in chunks so that each task can use a multi-summation.
    let digits: Vec<usize> = (0..base).collect();
    let digit_acc: Vec<ProjectivePoint<C>> = digits
        .par_chunks(DIGITS_PER_CHUNK)
        .flat_map(|chunk| {
            let summations: Vec<Vec<AffinePoint<C>>> = chunk
                .iter()
                .map(|&digit| {
                    digit_occurrences[digit]
                        .iter()
                        .map(|&(i, j)| precomputation.powers_per_generator[i][j])
                        .collect()
                })
                .collect();
            affine_multisummation_best(summations)
        })
        .collect();

    // Sequential final pass: walking the digits from largest to smallest, `u` holds the sum of
    // all per-digit totals for digits at least the current one.
    let mut y = ProjectivePoint::ZERO;
    let mut u = ProjectivePoint::ZERO;
    for digit in (1..base).rev() {
        u = u + digit_acc[digit];
        y = y + u;
    }
    y
}
/// Splits the scalar `x` into base-`2^w` digits, least-significant digit first.
///
/// The result always has `ceil(ScalarField::BITS / w)` entries; the top digit may cover fewer
/// than `w` bits. `w` must be nonzero.
pub(crate) fn to_digits<C: Curve>(x: &C::ScalarField, w: usize) -> Vec<usize> {
    let scalar_bits = C::ScalarField::BITS;
    let num_digits = (scalar_bits + w - 1) / w;
    // Round up, so that every bit index below `scalar_bits` has a backing limb even when the bit
    // length is not a multiple of 64.
    let num_limbs = (scalar_bits + 63) / 64;

    // `to_u64_digits` omits leading zero limbs, so pad back up to the full width.
    let limbs: Vec<u64> = x
        .to_biguint()
        .to_u64_digits()
        .iter()
        .cloned()
        .pad_using(num_limbs, |_| 0)
        .collect();
    let bit = |i: usize| (limbs[i / 64] >> (i % 64)) & 1 != 0;

    (0..num_digits)
        .map(|i| {
            let low = i * w;
            let high = ((i + 1) * w).min(scalar_bits);
            // Fold from the most significant bit of the window down to the least.
            (low..high)
                .rev()
                .fold(0usize, |digit, j| (digit << 1) | bit(j) as usize)
        })
        .collect()
}
// Unit tests for digit decomposition and the sequential MSM executor, using secp256k1.
#[cfg(test)]
mod tests {
use num::BigUint;
use crate::curve::curve_msm::{msm_execute, msm_precompute, to_digits};
use crate::curve::curve_types::Curve;
use crate::curve::secp256k1::Secp256K1;
use crate::field::field_types::Field;
use crate::field::secp256k1_scalar::Secp256K1Scalar;
// Checks the 17-bit-window decomposition of a fixed 256-bit scalar. The scalar is given as
// 32-bit limbs and the expected digits are listed least-significant first; the final digit
// holds the single leftover bit (256 = 15 * 17 + 1).
#[test]
fn test_to_digits() {
let x_canonical = [
0b10101010101010101010101010101010,
0b10101010101010101010101010101010,
0b11001100110011001100110011001100,
0b11001100110011001100110011001100,
0b11110000111100001111000011110000,
0b11110000111100001111000011110000,
0b00001111111111111111111111111111,
0b11111111111111111111111111111111,
];
let x = Secp256K1Scalar::from_biguint(BigUint::from_slice(&x_canonical));
// Sanity check: the value round-trips unchanged through the scalar field.
assert_eq!(x.to_biguint().to_u32_digits(), x_canonical);
assert_eq!(
to_digits::<Secp256K1>(&x, 17),
vec![
0b01010101010101010,
0b10101010101010101,
0b01010101010101010,
0b11001010101010101,
0b01100110011001100,
0b00110011001100110,
0b10011001100110011,
0b11110000110011001,
0b01111000011110000,
0b00111100001111000,
0b00011110000111100,
0b11111111111111110,
0b01111111111111111,
0b11111111111111000,
0b11111111111111111,
0b1,
]
);
}
// Compares `msm_execute` over three generators (G, 2G, 3G) against the sum of three individual
// scalar multiplications.
#[test]
fn test_msm() {
let w = 5;
let generator_1 = Secp256K1::GENERATOR_PROJECTIVE;
let generator_2 = generator_1 + generator_1;
let generator_3 = generator_1 + generator_2;
let scalar_1 = Secp256K1Scalar::from_biguint(BigUint::from_slice(&[
11111111, 22222222, 33333333, 44444444,
]));
let scalar_2 = Secp256K1Scalar::from_biguint(BigUint::from_slice(&[
22222222, 22222222, 33333333, 44444444,
]));
let scalar_3 = Secp256K1Scalar::from_biguint(BigUint::from_slice(&[
33333333, 22222222, 33333333, 44444444,
]));
let generators = vec![generator_1, generator_2, generator_3];
let scalars = vec![scalar_1, scalar_2, scalar_3];
let precomputation = msm_precompute(&generators, w);
let result_msm = msm_execute(&precomputation, &scalars);
let result_naive = Secp256K1::convert(scalar_1) * generator_1
+ Secp256K1::convert(scalar_2) * generator_2
+ Secp256K1::convert(scalar_3) * generator_3;
assert_eq!(result_msm, result_naive);
}
}

View File

@ -0,0 +1,97 @@
use std::ops::Mul;
use crate::curve::curve_types::{Curve, CurveScalar, ProjectivePoint};
use crate::field::field_types::Field;
/// Number of scalar bits consumed per window (digit) in the windowed multiplication below.
const WINDOW_BITS: usize = 4;
/// Number of distinct digit values, i.e. 2^WINDOW_BITS.
const BASE: usize = 1 << WINDOW_BITS;
/// Number of `WINDOW_BITS`-wide digits needed to cover every bit of a scalar, rounded up.
fn digits_per_scalar<C: Curve>() -> usize {
    let scalar_bits = C::ScalarField::BITS;
    (scalar_bits + WINDOW_BITS - 1) / WINDOW_BITS
}
/// Precomputed state used for scalar x ProjectivePoint multiplications,
/// specific to a particular generator. Built by `ProjectivePoint::mul_precompute`.
#[derive(Clone)]
pub struct MultiplicationPrecomputation<C: Curve> {
/// [(2^w)^i] g for each i < digits_per_scalar, where w is `WINDOW_BITS`.
powers: Vec<ProjectivePoint<C>>,
}
impl<C: Curve> ProjectivePoint<C> {
    /// Precomputes `[(2^WINDOW_BITS)^i] self` for every digit position `i` of a scalar.
    pub fn mul_precompute(&self) -> MultiplicationPrecomputation<C> {
        let num_digits = digits_per_scalar::<C>();
        let mut powers = Vec::with_capacity(num_digits);
        let mut current = *self;
        powers.push(current);
        for _ in 1..num_digits {
            // Moving to the next digit position multiplies by 2^WINDOW_BITS.
            for _ in 0..WINDOW_BITS {
                current = current.double();
            }
            powers.push(current);
        }
        MultiplicationPrecomputation { powers }
    }

    /// Multiplies the point that `precomputation` was built for by `scalar`.
    ///
    /// Only `precomputation` is consulted; `self` itself is not read. It is expected to be the
    /// point `mul_precompute` was called on, as in the `Mul` implementation for `CurveScalar`.
    pub fn mul_with_precomputation(
        &self,
        scalar: C::ScalarField,
        precomputation: MultiplicationPrecomputation<C>,
    ) -> Self {
        // Yao's method; see https://koclab.cs.ucsb.edu/teaching/ecc/eccPapers/Doche-ch09.pdf
        let precomputed_powers = precomputation.powers;
        let digits = to_digits::<C>(&scalar);

        // Walking the digit values from largest to smallest, `u` holds the sum of all powers
        // whose digit is at least the current value; adding `u` to `y` once per step weights
        // each power by its digit. Powers are accumulated straight into `u` (as `msm_execute`
        // does), so no per-digit vectors are built or cloned.
        let mut y: Self = ProjectivePoint::ZERO;
        let mut u: Self = ProjectivePoint::ZERO;
        for j in (1..BASE).rev() {
            let target = j as u64;
            for (i, &digit) in digits.iter().enumerate() {
                if digit == target {
                    u = u + precomputed_powers[i];
                }
            }
            y = y + u;
        }
        y
    }
}
impl<C: Curve> Mul<ProjectivePoint<C>> for CurveScalar<C> {
type Output = ProjectivePoint<C>;
fn mul(self, rhs: ProjectivePoint<C>) -> Self::Output {
let precomputation = rhs.mul_precompute();
rhs.mul_with_precomputation(self.0, precomputation)
}
}
/// Splits `x` into base-`BASE` digits, least-significant first.
///
/// Digits are produced limb by limb from `to_u64_digits`, so the output length is a multiple of
/// `64 / WINDOW_BITS` and may be shorter than `digits_per_scalar` (empty for zero).
// The assertion is on compile-time constants by design: it guards the limb-splitting scheme.
#[allow(clippy::assertions_on_constants)]
fn to_digits<C: Curve>(x: &C::ScalarField) -> Vec<u64> {
    debug_assert!(
        64 % WINDOW_BITS == 0,
        "For simplicity, only power-of-two window sizes are handled for now"
    );
    let digits_per_u64 = 64 / WINDOW_BITS;
    // BASE is a power of two, so reducing modulo BASE is the same as masking the low bits.
    let digit_mask = BASE as u64 - 1;

    let mut digits = Vec::with_capacity(digits_per_scalar::<C>());
    for limb in x.to_biguint().to_u64_digits() {
        digits.extend((0..digits_per_u64).map(|j| (limb >> ((j * WINDOW_BITS) as u64)) & digit_mask));
    }
    digits
}

View File

@ -0,0 +1,237 @@
use std::iter::Sum;
use crate::curve::curve_types::{AffinePoint, Curve, ProjectivePoint};
use crate::field::field_types::Field;
impl<C: Curve> Sum<AffinePoint<C>> for ProjectivePoint<C> {
    /// Sums an iterator of affine points by collecting them and delegating to
    /// `affine_summation_best`.
    fn sum<I: Iterator<Item = AffinePoint<C>>>(iter: I) -> ProjectivePoint<C> {
        affine_summation_best(iter.collect::<Vec<_>>())
    }
}
impl<C: Curve> Sum for ProjectivePoint<C> {
    /// Sums projective points left to right, starting from the identity.
    fn sum<I: Iterator<Item = ProjectivePoint<C>>>(iter: I) -> ProjectivePoint<C> {
        let mut total: ProjectivePoint<C> = ProjectivePoint::ZERO;
        for point in iter {
            total = total + point;
        }
        total
    }
}
/// Sums a single list of affine points by running it through `affine_multisummation_best` as a
/// one-element batch.
pub fn affine_summation_best<C: Curve>(summation: Vec<AffinePoint<C>>) -> ProjectivePoint<C> {
    let batch = vec![summation];
    let totals = affine_multisummation_best(batch);
    debug_assert_eq!(totals.len(), 1);
    totals[0]
}
/// Computes one sum per input list, choosing between the pairwise and the batch-inversion
/// strategy based on how many pairwise additions the first reduction round would perform.
pub fn affine_multisummation_best<C: Curve>(
    summations: Vec<Vec<AffinePoint<C>>>,
) -> Vec<ProjectivePoint<C>> {
    let mut pairwise_sums: usize = 0;
    for summation in &summations {
        pairwise_sums += summation.len() / 2;
    }

    // This threshold is chosen based on data from the summation benchmarks.
    if pairwise_sums >= 70 {
        return affine_multisummation_batch_inversion(summations);
    }
    affine_multisummation_pairwise(summations)
}
/// Adds each pair of points using an affine + affine = projective formula, then adds up the
/// intermediate sums using a projective formula.
pub fn affine_multisummation_pairwise<C: Curve>(
summations: Vec<Vec<AffinePoint<C>>>,
) -> Vec<ProjectivePoint<C>> {
summations
.into_iter()
.map(affine_summation_pairwise)
.collect()
}
/// Adds each pair of points using an affine + affine = projective formula, then accumulates the
/// intermediate sums (and a trailing unpaired point, if any) using a projective formula.
pub fn affine_summation_pairwise<C: Curve>(points: Vec<AffinePoint<C>>) -> ProjectivePoint<C> {
    let mut total: ProjectivePoint<C> = ProjectivePoint::ZERO;
    for chunk in points.chunks(2) {
        let partial = match chunk {
            [single] => single.to_projective(),
            [first, second] => *first + *second,
            // `chunks(2)` only ever yields slices of length 1 or 2.
            _ => panic!(),
        };
        total = total + partial;
    }
    total
}
/// Computes a single summation of affine points by applying an affine group law, except that
/// the divisions are batched via Montgomery's trick. Delegates to
/// `affine_multisummation_batch_inversion` with a one-element batch.
pub fn affine_summation_batch_inversion<C: Curve>(
    summation: Vec<AffinePoint<C>>,
) -> ProjectivePoint<C> {
    let batch = vec![summation];
    let totals = affine_multisummation_batch_inversion(batch);
    debug_assert_eq!(totals.len(), 1);
    totals[0]
}
/// Computes several summations of affine points by applying an affine group law, except that the
/// divisions are batched via Montgomery's trick.
///
/// One round adds adjacent pairs in every list (roughly halving each list) using a single batch
/// inversion, then hands the reduced lists back to `affine_multisummation_best`.
pub fn affine_multisummation_batch_inversion<C: Curve>(
    summations: Vec<Vec<AffinePoint<C>>>,
) -> Vec<ProjectivePoint<C>> {
    // First pass: for each pair (x1, y1), (x2, y2) that will be added, collect the slope
    // denominator: 2 * y1 if the points are equal, x1 - x2 otherwise. Pairs that involve the
    // identity or cancel each other need no inverse.
    let mut elements_to_invert = Vec::new();
    for summation in &summations {
        for pair in summation.chunks_exact(2) {
            let (p1, p2) = (pair[0], pair[1]);
            if p1.zero || p2.zero || p1 == -p2 {
                continue;
            }
            let denominator = if p1 == p2 {
                p1.y.double()
            } else {
                p1.x - p2.x
            };
            elements_to_invert.push(denominator);
        }
    }

    let inverses: Vec<C::BaseField> =
        C::BaseField::batch_multiplicative_inverse(&elements_to_invert);

    // Second pass: replay the same pairs in the same order, consuming one inverse for each
    // non-trivial addition.
    let mut all_reduced_points = Vec::with_capacity(summations.len());
    let mut inverse_index = 0;
    for summation in summations {
        let mut reduced_points = Vec::with_capacity((summation.len() + 1) / 2);
        let pairs = summation.chunks_exact(2);
        // Empty, or the single trailing point of an odd-length list.
        let unpaired = pairs.remainder();

        for pair in pairs {
            let (p1, p2) = (pair[0], pair[1]);
            let (x1, y1) = (p1.x, p1.y);
            let (x2, y2) = (p2.x, p2.y);

            let sum = if p1.zero {
                p2
            } else if p2.zero {
                p1
            } else if p1 == -p2 {
                AffinePoint::ZERO
            } else {
                // It's a non-trivial case where we need one of the inverses we computed earlier.
                let inverse = inverses[inverse_index];
                inverse_index += 1;

                if p1 == p2 {
                    // Doubling: the slope is (3 * x1^2 + A) / (2 * y1).
                    let mut numerator = x1.square().triple();
                    if C::A.is_nonzero() {
                        numerator += C::A;
                    }
                    let slope = numerator * inverse;
                    let x3 = slope.square() - x1.double();
                    let y3 = slope * (x1 - x3) - y1;
                    AffinePoint::nonzero(x3, y3)
                } else {
                    // This is the general case. We use the incomplete addition formulas 4.3 and 4.4.
                    let slope = (y1 - y2) * inverse;
                    let x3 = slope.square() - x1 - x2;
                    let y3 = slope * (x1 - x3) - y1;
                    AffinePoint::nonzero(x3, y3)
                }
            };
            reduced_points.push(sum);
        }

        // If the list had odd length, its last point was not part of a pair.
        reduced_points.extend_from_slice(unpaired);
        all_reduced_points.push(reduced_points);
    }

    // We should have consumed all of the inverses from the batch computation.
    debug_assert_eq!(inverse_index, inverses.len());

    // Recurse with our smaller set of points.
    affine_multisummation_best(all_reduced_points)
}
// Unit tests for the two affine summation strategies, using the secp256k1 generator.
#[cfg(test)]
mod tests {
use crate::curve::curve_summation::{
affine_summation_batch_inversion, affine_summation_pairwise,
};
use crate::curve::curve_types::{Curve, ProjectivePoint};
use crate::curve::secp256k1::Secp256K1;
// Pairwise strategy: covers a doubling pair, a distinct pair, an odd-length list and the
// empty list.
#[test]
fn test_pairwise_affine_summation() {
let g_affine = Secp256K1::GENERATOR_AFFINE;
let g2_affine = (g_affine + g_affine).to_affine();
let g3_affine = (g_affine + g_affine + g_affine).to_affine();
let g2_proj = g2_affine.to_projective();
let g3_proj = g3_affine.to_projective();
assert_eq!(
affine_summation_pairwise::<Secp256K1>(vec![g_affine, g_affine]),
g2_proj
);
assert_eq!(
affine_summation_pairwise::<Secp256K1>(vec![g_affine, g2_affine]),
g3_proj
);
assert_eq!(
affine_summation_pairwise::<Secp256K1>(vec![g_affine, g_affine, g_affine]),
g3_proj
);
assert_eq!(
affine_summation_pairwise::<Secp256K1>(vec![]),
ProjectivePoint::ZERO
);
}
// Batch-inversion strategy: covers a doubling pair, an odd-length list and the empty list,
// compared against plain projective additions.
#[test]
fn test_pairwise_affine_summation_batch_inversion() {
let g = Secp256K1::GENERATOR_AFFINE;
let g_proj = g.to_projective();
assert_eq!(
affine_summation_batch_inversion::<Secp256K1>(vec![g, g]),
g_proj + g_proj
);
assert_eq!(
affine_summation_batch_inversion::<Secp256K1>(vec![g, g, g]),
g_proj + g_proj + g_proj
);
assert_eq!(
affine_summation_batch_inversion::<Secp256K1>(vec![]),
ProjectivePoint::ZERO
);
}
}

260
src/curve/curve_types.rs Normal file
View File

@ -0,0 +1,260 @@
use std::fmt::Debug;
use std::ops::Neg;
use crate::field::field_types::Field;
/// Newtype wrapper around an element of a curve's scalar field; see `Curve::convert`.
// To avoid implementation conflicts from associated types,
// see https://github.com/rust-lang/rust/issues/20400
pub struct CurveScalar<C: Curve>(pub <C as Curve>::ScalarField);
/// A short Weierstrass curve, y^2 = x^3 + A*x + B, together with a distinguished generator.
pub trait Curve: 'static + Sync + Sized + Copy + Debug {
    /// Field the point coordinates live in.
    type BaseField: Field;
    /// Field of scalars used to multiply points.
    type ScalarField: Field;

    /// Coefficient of the linear term of the curve equation.
    const A: Self::BaseField;
    /// Constant term of the curve equation.
    const B: Self::BaseField;

    /// The generator in affine coordinates.
    const GENERATOR_AFFINE: AffinePoint<Self>;

    /// The generator in projective coordinates (same x and y, with z = 1).
    const GENERATOR_PROJECTIVE: ProjectivePoint<Self> = ProjectivePoint {
        x: Self::GENERATOR_AFFINE.x,
        y: Self::GENERATOR_AFFINE.y,
        z: Self::BaseField::ONE,
    };

    /// Wraps a scalar-field element in `CurveScalar`.
    fn convert(x: Self::ScalarField) -> CurveScalar<Self> {
        CurveScalar(x)
    }

    /// Returns whether 4*A^3 + 27*B^2 is nonzero.
    fn is_safe_curve() -> bool {
        // Additional check to prevent exploiting vulnerabilities that arise when the
        // discriminant is equal to 0.
        let four_a_cubed = Self::A.cube().double().double();
        let twenty_seven_b_squared = Self::B.square().triple().triple().triple();
        (four_a_cubed + twenty_seven_b_squared).is_nonzero()
    }
}
/// A point on a short Weierstrass curve, represented in affine coordinates.
///
/// The identity (point at infinity) has no affine coordinates, so it is flagged by `zero`; when
/// `zero` is set, equality and validity checks ignore `x` and `y`.
#[derive(Copy, Clone, Debug)]
pub struct AffinePoint<C: Curve> {
pub x: C::BaseField,
pub y: C::BaseField,
// True iff this is the identity point.
pub zero: bool,
}
impl<C: Curve> AffinePoint<C> {
    /// The identity point. Its coordinates are placeholders; only the `zero` flag matters.
    pub const ZERO: Self = Self {
        x: C::BaseField::ZERO,
        y: C::BaseField::ZERO,
        zero: true,
    };

    /// Builds a non-identity point from its coordinates. In debug builds, asserts that the
    /// point satisfies the curve equation.
    pub fn nonzero(x: C::BaseField, y: C::BaseField) -> Self {
        let point = Self { x, y, zero: false };
        debug_assert!(point.is_valid());
        point
    }

    /// Returns whether this is the identity or satisfies y^2 = x^3 + A*x + B.
    pub fn is_valid(&self) -> bool {
        let Self { x, y, zero } = *self;
        zero || y.square() == x.cube() + C::A * x + C::B
    }

    /// Converts to projective coordinates: z = 1 for ordinary points and z = 0 for the identity
    /// (x and y are carried over unchanged in both cases).
    pub fn to_projective(&self) -> ProjectivePoint<C> {
        let Self { x, y, zero } = *self;
        let z = if zero {
            C::BaseField::ZERO
        } else {
            C::BaseField::ONE
        };
        ProjectivePoint { x, y, z }
    }

    /// Converts every point of a slice to projective coordinates.
    pub fn batch_to_projective(affine_points: &[Self]) -> Vec<ProjectivePoint<C>> {
        affine_points.iter().map(Self::to_projective).collect()
    }

    /// Doubles the point using the affine tangent formula (one field inversion).
    pub fn double(&self) -> Self {
        let AffinePoint { x: x1, y: y1, zero } = *self;

        // The identity doubles to itself. A point with y = 0 has a vertical tangent, so its
        // double is the identity as well; this must be caught here because 2y = 0 has no
        // inverse.
        if zero || y1.is_zero() {
            return AffinePoint::ZERO;
        }

        let double_y = y1.double();
        let inv_double_y = double_y.inverse(); // (2y)^(-1)
        let triple_xx = x1.square().triple(); // 3x^2
        let lambda = (triple_xx + C::A) * inv_double_y;
        let x3 = lambda.square() - x1.double();
        let y3 = lambda * (x1 - x3) - y1;

        Self {
            x: x3,
            y: y3,
            zero: false,
        }
    }
}
impl<C: Curve> PartialEq for AffinePoint<C> {
    /// Two points are equal if both are the identity, or neither is and their coordinates match.
    /// The coordinates of an identity point are ignored.
    fn eq(&self, other: &Self) -> bool {
        if self.zero || other.zero {
            self.zero == other.zero
        } else {
            self.x == other.x && self.y == other.y
        }
    }
}

impl<C: Curve> Eq for AffinePoint<C> {}
/// A point on a short Weierstrass curve, represented in projective coordinates.
///
/// A nonzero `z` represents the affine point (x/z, y/z); `z == 0` represents the identity.
#[derive(Copy, Clone, Debug)]
pub struct ProjectivePoint<C: Curve> {
pub x: C::BaseField,
pub y: C::BaseField,
pub z: C::BaseField,
}
impl<C: Curve> ProjectivePoint<C> {
    /// The identity point, represented as (0, 1, 0).
    pub const ZERO: Self = Self {
        x: C::BaseField::ZERO,
        y: C::BaseField::ONE,
        z: C::BaseField::ZERO,
    };

    /// Builds a point from its coordinates. In debug builds, asserts that the point satisfies
    /// the projective curve equation (or has z = 0).
    pub fn nonzero(x: C::BaseField, y: C::BaseField, z: C::BaseField) -> Self {
        let point = Self { x, y, z };
        debug_assert!(point.is_valid());
        point
    }

    /// Returns whether z = 0 or the point satisfies y^2*z = x^3 + A*x*z^2 + B*z^3.
    pub fn is_valid(&self) -> bool {
        let Self { x, y, z } = *self;
        z.is_zero() || y.square() * z == x.cube() + C::A * x * z.square() + C::B * z.cube()
    }

    /// Converts to affine coordinates, using one field inversion for non-identity points.
    pub fn to_affine(&self) -> AffinePoint<C> {
        let Self { x, y, z } = *self;
        if z == C::BaseField::ZERO {
            AffinePoint::ZERO
        } else {
            let z_inv = z.inverse();
            AffinePoint::nonzero(x * z_inv, y * z_inv)
        }
    }

    /// Converts a slice of points to affine coordinates with a single batched inversion.
    pub fn batch_to_affine(proj_points: &[Self]) -> Vec<AffinePoint<C>> {
        // Identity points have z = 0, which has no inverse. Substitute ONE for those entries so
        // that the batch inversion only ever sees invertible elements; their (unused) result is
        // discarded below.
        let zs: Vec<C::BaseField> = proj_points
            .iter()
            .map(|pp| {
                if pp.z == C::BaseField::ZERO {
                    C::BaseField::ONE
                } else {
                    pp.z
                }
            })
            .collect();
        let z_invs = C::BaseField::batch_multiplicative_inverse(&zs);
        debug_assert_eq!(z_invs.len(), proj_points.len());

        proj_points
            .iter()
            .zip(z_invs.iter())
            .map(|(pp, &z_inv)| {
                if pp.z == C::BaseField::ZERO {
                    AffinePoint::ZERO
                } else {
                    AffinePoint::nonzero(pp.x * z_inv, pp.y * z_inv)
                }
            })
            .collect()
    }

    /// Doubles the point.
    // From https://www.hyperelliptic.org/EFD/g1p/data/shortw/projective/doubling/dbl-2007-bl
    pub fn double(&self) -> Self {
        let Self { x, y, z } = *self;
        if z == C::BaseField::ZERO {
            return ProjectivePoint::ZERO;
        }

        let xx = x.square();
        let zz = z.square();
        let mut w = xx.triple();
        // Skip the multiplication entirely for curves with A = 0.
        if C::A.is_nonzero() {
            w += C::A * zz;
        }
        let s = y.double() * z;
        let r = y * s;
        let rr = r.square();
        let b = (x + r).square() - (xx + rr);
        let h = w.square() - b.double();
        let x3 = h * s;
        let y3 = w * (b - h) - rr.double();
        let z3 = s.cube();

        Self {
            x: x3,
            y: y3,
            z: z3,
        }
    }

    /// Adds two equally long slices element-wise.
    ///
    /// # Panics
    ///
    /// Panics if the slices have different lengths.
    pub fn add_slices(a: &[Self], b: &[Self]) -> Vec<Self> {
        assert_eq!(a.len(), b.len());
        a.iter()
            .zip(b.iter())
            .map(|(&a_i, &b_i)| a_i + b_i)
            .collect()
    }

    /// Returns the negation of this point, obtained by negating `y`.
    pub fn neg(&self) -> Self {
        Self {
            x: self.x,
            y: -self.y,
            z: self.z,
        }
    }
}
impl<C: Curve> PartialEq for ProjectivePoint<C> {
    /// Compares the points represented, not the raw coordinates: representations that differ
    /// only by a common scaling factor compare equal.
    fn eq(&self, other: &Self) -> bool {
        let (x1, y1, z1) = (self.x, self.y, self.z);
        let (x2, y2, z2) = (other.x, other.y, other.z);

        let either_is_identity = z1 == C::BaseField::ZERO || z2 == C::BaseField::ZERO;
        if either_is_identity {
            z1 == z2
        } else {
            // We want to compare (x1/z1, y1/z1) == (x2/z2, y2/z2).
            // But to avoid field division, it is better to compare
            // (x1*z2, y1*z2) == (x2*z1, y2*z1).
            x1 * z2 == x2 * z1 && y1 * z2 == y2 * z1
        }
    }
}

impl<C: Curve> Eq for ProjectivePoint<C> {}
impl<C: Curve> Neg for AffinePoint<C> {
    type Output = AffinePoint<C>;

    /// Negates the point by negating `y`; `x` and the `zero` flag are unchanged.
    fn neg(self) -> Self::Output {
        AffinePoint {
            y: -self.y,
            ..self
        }
    }
}
impl<C: Curve> Neg for ProjectivePoint<C> {
    type Output = ProjectivePoint<C>;

    /// Negates the point by negating `y`; `x` and `z` are unchanged.
    fn neg(self) -> Self::Output {
        ProjectivePoint {
            y: -self.y,
            ..self
        }
    }
}

6
src/curve/mod.rs Normal file
View File

@ -0,0 +1,6 @@
pub mod curve_adds;
pub mod curve_msm;
pub mod curve_multiplication;
pub mod curve_summation;
pub mod curve_types;
pub mod secp256k1;

98
src/curve/secp256k1.rs Normal file
View File

@ -0,0 +1,98 @@
use crate::curve::curve_types::{AffinePoint, Curve};
use crate::field::field_types::Field;
use crate::field::secp256k1_base::Secp256K1Base;
use crate::field::secp256k1_scalar::Secp256K1Scalar;
/// Marker type for the secp256k1 curve; all curve data lives in its `Curve` implementation.
#[derive(Debug, Copy, Clone)]
pub struct Secp256K1;
// Curve parameters for secp256k1. With A = 0 the curve equation reduces to y^2 = x^3 + B.
impl Curve for Secp256K1 {
type BaseField = Secp256K1Base;
type ScalarField = Secp256K1Scalar;
const A: Secp256K1Base = Secp256K1Base::ZERO;
// NOTE(review): presumably the limbs are least-significant first, making this B = 7 —
// confirm against `Secp256K1Base`.
const B: Secp256K1Base = Secp256K1Base([7, 0, 0, 0]);
const GENERATOR_AFFINE: AffinePoint<Self> = AffinePoint {
x: SECP256K1_GENERATOR_X,
y: SECP256K1_GENERATOR_Y,
zero: false,
};
}
/// x-coordinate of the secp256k1 generator, as four 64-bit limbs.
/// NOTE(review): limbs appear to be listed least-significant first — confirm against
/// `Secp256K1Base`.
///
/// 55066263022277343669578718895168534326250603453777594175500187360389116729240
const SECP256K1_GENERATOR_X: Secp256K1Base = Secp256K1Base([
0x59F2815B16F81798,
0x029BFCDB2DCE28D9,
0x55A06295CE870B07,
0x79BE667EF9DCBBAC,
]);
/// 32670510020758816978083085130507043184471273380659243275938904335757337482424
const SECP256K1_GENERATOR_Y: Secp256K1Base = Secp256K1Base([
0x9C47D08FFB10D4B8,
0xFD17B448A6855419,
0x5DA4FBFC0E1108A8,
0x483ADA7726A3C465,
]);
#[cfg(test)]
mod tests {
    use num::BigUint;

    use crate::curve::curve_types::{AffinePoint, Curve, ProjectivePoint};
    use crate::curve::secp256k1::Secp256K1;
    use crate::field::field_types::Field;
    use crate::field::secp256k1_scalar::Secp256K1Scalar;

    /// Both the generator and the point with negated `y` must pass `is_valid`.
    #[test]
    fn test_generator() {
        let generator = Secp256K1::GENERATOR_AFFINE;
        assert!(generator.is_valid());

        let negated = AffinePoint::<Secp256K1> {
            y: -generator.y,
            ..generator
        };
        assert!(negated.is_valid());
    }

    /// `mul_naive(10, g)` must agree with adding `g` to itself ten times.
    #[test]
    fn test_naive_multiplication() {
        let g = Secp256K1::GENERATOR_PROJECTIVE;
        let ten = Secp256K1Scalar::from_canonical_u64(10);
        let product = mul_naive(ten, g);
        // Nine left-to-right additions starting from `g`, i.e. g + g + ... + g (ten terms).
        let sum = (1..10).fold(g, |acc, _| acc + g);
        assert_eq!(product, sum);
    }

    /// The curve's scalar multiplication must agree with the naive reference implementation.
    #[test]
    fn test_g1_multiplication() {
        let lhs = Secp256K1Scalar::from_biguint(BigUint::from_slice(&[
            1111, 2222, 3333, 4444, 5555, 6666, 7777, 8888,
        ]));
        let expected = mul_naive(lhs, Secp256K1::GENERATOR_PROJECTIVE);
        let actual = Secp256K1::convert(lhs) * Secp256K1::GENERATOR_PROJECTIVE;
        assert_eq!(actual, expected);
    }

    /// A simple, somewhat inefficient double-and-add multiplication, used as a reference for
    /// correctness. Scans the scalar's 64-bit limbs from least to most significant bit.
    fn mul_naive(
        lhs: Secp256K1Scalar,
        rhs: ProjectivePoint<Secp256K1>,
    ) -> ProjectivePoint<Secp256K1> {
        let mut addend = rhs;
        let mut acc = ProjectivePoint::ZERO;
        for limb in lhs.to_biguint().to_u64_digits() {
            for bit in 0..64 {
                if (limb >> bit) & 1u64 != 0u64 {
                    acc = acc + addend;
                }
                addend = addend.double();
            }
        }
        acc
    }
}

View File

@ -31,8 +31,6 @@ mod tests {
#[test]
fn distinct_cosets() {
// TODO: Switch to a smaller test field so that collision rejection is likely to occur.
type F = GoldilocksField;
const SUBGROUP_BITS: usize = 5;
const NUM_SHIFTS: usize = 50;

View File

@ -160,12 +160,32 @@ impl<F: OEF<D>, const D: usize> PolynomialCoeffsAlgebra<F, D> {
.fold(ExtensionAlgebra::ZERO, |acc, &c| acc * x + c)
}
/// Evaluate the polynomial at a point given its powers. The first power is the point itself, not 1.
///
/// `powers[i]` is paired with `coeffs[i + 1]`; `coeffs[0]` is the constant term.
pub fn eval_with_powers(&self, powers: &[ExtensionAlgebra<F, D>]) -> ExtensionAlgebra<F, D> {
    debug_assert_eq!(self.coeffs.len(), powers.len() + 1);
    let mut total = self.coeffs[0];
    for (&coeff, &power) in self.coeffs[1..].iter().zip(powers) {
        total = total + power * coeff;
    }
    total
}
/// Evaluate the polynomial at a base-field point `x` using Horner's rule, starting from the
/// highest-degree coefficient.
pub fn eval_base(&self, x: F) -> ExtensionAlgebra<F, D> {
    let mut acc: ExtensionAlgebra<F, D> = ExtensionAlgebra::ZERO;
    for &coeff in self.coeffs.iter().rev() {
        acc = acc.scalar_mul(x) + coeff;
    }
    acc
}
/// Evaluate the polynomial at a point given its powers. The first power is the point itself, not 1.
///
/// `powers[i]` scales `coeffs[i + 1]`; `coeffs[0]` is the constant term.
pub fn eval_base_with_powers(&self, powers: &[F]) -> ExtensionAlgebra<F, D> {
    debug_assert_eq!(self.coeffs.len(), powers.len() + 1);
    let mut total = self.coeffs[0];
    for (&coeff, &power) in self.coeffs[1..].iter().zip(powers) {
        total = total + coeff.scalar_mul(power);
    }
    total
}
}
#[cfg(test)]

View File

@ -1,3 +1,4 @@
use crate::field::field_types::{Field, PrimeField};
use std::convert::TryInto;
use crate::field::field_types::{Field, RichField};

View File

@ -3,6 +3,7 @@ use std::iter::{Product, Sum};
use std::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssign};
use num::bigint::BigUint;
use num::Integer;
use rand::Rng;
use serde::{Deserialize, Serialize};
@ -49,26 +50,28 @@ impl<F: Extendable<2>> From<F> for QuadraticExtension<F> {
}
impl<F: Extendable<2>> Field for QuadraticExtension<F> {
type PrimeField = F;
const ZERO: Self = Self([F::ZERO; 2]);
const ONE: Self = Self([F::ONE, F::ZERO]);
const TWO: Self = Self([F::TWO, F::ZERO]);
const NEG_ONE: Self = Self([F::NEG_ONE, F::ZERO]);
const CHARACTERISTIC: u64 = F::CHARACTERISTIC;
// `p^2 - 1 = (p - 1)(p + 1)`. The `p - 1` term has a two-adicity of `F::TWO_ADICITY`. As
// long as `F::TWO_ADICITY >= 2`, `p` can be written as `4n + 1`, so `p + 1` can be written as
// `2(2n + 1)`, which has a 2-adicity of 1.
const TWO_ADICITY: usize = F::TWO_ADICITY + 1;
const CHARACTERISTIC_TWO_ADICITY: usize = F::CHARACTERISTIC_TWO_ADICITY;
const MULTIPLICATIVE_GROUP_GENERATOR: Self = Self(F::EXT_MULTIPLICATIVE_GROUP_GENERATOR);
const POWER_OF_TWO_GENERATOR: Self = Self(F::EXT_POWER_OF_TWO_GENERATOR);
const BITS: usize = F::BITS * 2;
fn order() -> BigUint {
F::order() * F::order()
}
fn characteristic() -> BigUint {
F::characteristic()
}
#[inline(always)]
fn square(&self) -> Self {
@ -99,6 +102,15 @@ impl<F: Extendable<2>> Field for QuadraticExtension<F> {
))
}
/// Splits `n` into base-`F::order()` digits: the remainder becomes limb 0 and the quotient is
/// handed to `F::from_biguint` for limb 1.
fn from_biguint(n: BigUint) -> Self {
    let base_order = F::order();
    let (quotient, remainder) = n.div_rem(&base_order);
    Self([F::from_biguint(remainder), F::from_biguint(quotient)])
}
/// Returns `limb0 + F::order() * limb1` as a `BigUint`.
fn to_biguint(&self) -> BigUint {
    let [low, high] = &self.0;
    low.to_biguint() + F::order() * high.to_biguint()
}
fn from_canonical_u64(n: u64) -> Self {
F::from_canonical_u64(n).into()
}

View File

@ -4,6 +4,7 @@ use std::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssi
use num::bigint::BigUint;
use num::traits::Pow;
use num::Integer;
use rand::Rng;
use serde::{Deserialize, Serialize};
@ -50,27 +51,29 @@ impl<F: Extendable<4>> From<F> for QuarticExtension<F> {
}
impl<F: Extendable<4>> Field for QuarticExtension<F> {
type PrimeField = F;
const ZERO: Self = Self([F::ZERO; 4]);
const ONE: Self = Self([F::ONE, F::ZERO, F::ZERO, F::ZERO]);
const TWO: Self = Self([F::TWO, F::ZERO, F::ZERO, F::ZERO]);
const NEG_ONE: Self = Self([F::NEG_ONE, F::ZERO, F::ZERO, F::ZERO]);
const CHARACTERISTIC: u64 = F::ORDER;
// `p^4 - 1 = (p - 1)(p + 1)(p^2 + 1)`. The `p - 1` term has a two-adicity of `F::TWO_ADICITY`.
// As long as `F::TWO_ADICITY >= 2`, `p` can be written as `4n + 1`, so `p + 1` can be written as
// `2(2n + 1)`, which has a 2-adicity of 1. A similar argument can show that `p^2 + 1` also has
// a 2-adicity of 1.
const TWO_ADICITY: usize = F::TWO_ADICITY + 2;
const CHARACTERISTIC_TWO_ADICITY: usize = F::CHARACTERISTIC_TWO_ADICITY;
const MULTIPLICATIVE_GROUP_GENERATOR: Self = Self(F::EXT_MULTIPLICATIVE_GROUP_GENERATOR);
const POWER_OF_TWO_GENERATOR: Self = Self(F::EXT_POWER_OF_TWO_GENERATOR);
const BITS: usize = F::BITS * 4;
fn order() -> BigUint {
F::order().pow(4u32)
}
fn characteristic() -> BigUint {
F::characteristic()
}
#[inline(always)]
fn square(&self) -> Self {
@ -104,6 +107,26 @@ impl<F: Extendable<4>> Field for QuarticExtension<F> {
))
}
/// Splits `n` into base-`F::order()` digits, least significant first. The quotient left after
/// three divisions is handed to `F::from_biguint` as the top limb.
fn from_biguint(n: BigUint) -> Self {
    let base_order = F::order();
    let (quotient, limb0) = n.div_rem(&base_order);
    let (quotient, limb1) = quotient.div_rem(&base_order);
    let (limb3, limb2) = quotient.div_rem(&base_order);
    Self([
        F::from_biguint(limb0),
        F::from_biguint(limb1),
        F::from_biguint(limb2),
        F::from_biguint(limb3),
    ])
}
/// Recombines the four limbs as base-`F::order()` digits (limb 0 least significant), using
/// Horner's rule from the top limb down.
fn to_biguint(&self) -> BigUint {
    self.0[..3]
        .iter()
        .rev()
        .fold(self.0[3].to_biguint(), |acc, limb| {
            acc * F::order() + limb.to_biguint()
        })
}
fn from_canonical_u64(n: u64) -> Self {
F::from_canonical_u64(n).into()
}

View File

@ -1,4 +1,3 @@
use std::convert::{TryFrom, TryInto};
use std::ops::Range;
use crate::field::extension_field::algebra::ExtensionAlgebra;
@ -33,6 +32,7 @@ impl<const D: usize> ExtensionTarget<D> {
let arr = self.to_target_array();
let k = (F::order() - 1u32) / (D as u64);
let z0 = F::Extension::W.exp_biguint(&(k * count as u64));
#[allow(clippy::needless_collect)]
let zs = z0
.powers()
.take(D)

View File

@ -5,8 +5,8 @@ use unroll::unroll_for_loops;
use crate::field::field_types::Field;
use crate::field::packable::Packable;
use crate::field::packed_field::{PackedField, Singleton};
use crate::polynomial::polynomial::{PolynomialCoeffs, PolynomialValues};
use crate::field::packed_field::PackedField;
use crate::polynomial::{PolynomialCoeffs, PolynomialValues};
use crate::util::{log2_strict, reverse_index_bits};
pub(crate) type FftRootTable<F> = Vec<Vec<F>>;
@ -38,7 +38,7 @@ fn fft_dispatch<F: Field>(
zero_factor: Option<usize>,
root_table: Option<&FftRootTable<F>>,
) -> Vec<F> {
let computed_root_table = if let Some(_) = root_table {
let computed_root_table = if root_table.is_some() {
None
} else {
Some(fft_root_table(input.len()))
@ -98,12 +98,12 @@ pub fn ifft_with_options<F: Field>(
/// Generic FFT implementation that works with both scalar and packed inputs.
#[unroll_for_loops]
fn fft_classic_simd<P: PackedField>(
values: &mut [P::FieldType],
values: &mut [P::Scalar],
r: usize,
lg_n: usize,
root_table: &FftRootTable<P::FieldType>,
root_table: &FftRootTable<P::Scalar>,
) {
let lg_packed_width = P::LOG2_WIDTH; // 0 when P is a scalar.
let lg_packed_width = log2_strict(P::WIDTH); // 0 when P is a scalar.
let packed_values = P::pack_slice_mut(values);
let packed_n = packed_values.len();
debug_assert!(packed_n == 1 << (lg_n - lg_packed_width));
@ -121,19 +121,18 @@ fn fft_classic_simd<P: PackedField>(
let half_m = 1 << lg_half_m;
// Set omega to root_table[lg_half_m][0..half_m] but repeated.
let mut omega_vec = P::zero().to_vec();
for j in 0..omega_vec.len() {
omega_vec[j] = root_table[lg_half_m][j % half_m];
let mut omega = P::ZERO;
for (j, omega_j) in omega.as_slice_mut().iter_mut().enumerate() {
*omega_j = root_table[lg_half_m][j % half_m];
}
let omega = P::from_slice(&omega_vec[..]);
for k in (0..packed_n).step_by(2) {
// We have two vectors and want to do math on pairs of adjacent elements (or for
// lg_half_m > 0, pairs of adjacent blocks of elements). .interleave does the
// appropriate shuffling and is its own inverse.
let (u, v) = packed_values[k].interleave(packed_values[k + 1], lg_half_m);
let (u, v) = packed_values[k].interleave(packed_values[k + 1], half_m);
let t = omega * v;
(packed_values[k], packed_values[k + 1]) = (u + t).interleave(u - t, lg_half_m);
(packed_values[k], packed_values[k + 1]) = (u + t).interleave(u - t, half_m);
}
}
}
@ -197,13 +196,13 @@ pub(crate) fn fft_classic<F: Field>(input: &[F], r: usize, root_table: &FftRootT
}
}
let lg_packed_width = <F as Packable>::PackedType::LOG2_WIDTH;
let lg_packed_width = log2_strict(<F as Packable>::Packing::WIDTH);
if lg_n <= lg_packed_width {
// Need the slice to be at least the width of two packed vectors for the vectorized version
// to work. Do this tiny problem in scalar.
fft_classic_simd::<Singleton<F>>(&mut values[..], r, lg_n, &root_table);
fft_classic_simd::<F>(&mut values[..], r, lg_n, root_table);
} else {
fft_classic_simd::<<F as Packable>::PackedType>(&mut values[..], r, lg_n, &root_table);
fft_classic_simd::<<F as Packable>::Packing>(&mut values[..], r, lg_n, root_table);
}
values
}
@ -213,19 +212,23 @@ mod tests {
use crate::field::fft::{fft, fft_with_options, ifft};
use crate::field::field_types::Field;
use crate::field::goldilocks_field::GoldilocksField;
use crate::polynomial::polynomial::{PolynomialCoeffs, PolynomialValues};
use crate::polynomial::{PolynomialCoeffs, PolynomialValues};
use crate::util::{log2_ceil, log2_strict};
#[test]
fn fft_and_ifft() {
type F = GoldilocksField;
let degree = 200;
let degree_padded = log2_ceil(degree);
let mut coefficients = Vec::new();
for i in 0..degree {
coefficients.push(F::from_canonical_usize(i * 1337 % 100));
}
let coefficients = PolynomialCoeffs::new_padded(coefficients);
let degree = 200usize;
let degree_padded = degree.next_power_of_two();
// Create a vector of coeffs; the first degree of them are
// "random", the last degree_padded-degree of them are zero.
let coeffs = (0..degree)
.map(|i| F::from_canonical_usize(i * 1337 % 100))
.chain(std::iter::repeat(F::ZERO).take(degree_padded - degree))
.collect::<Vec<_>>();
assert_eq!(coeffs.len(), degree_padded);
let coefficients = PolynomialCoeffs { coeffs };
let points = fft(&coefficients);
assert_eq!(points, evaluate_naive(&coefficients));
@ -263,7 +266,7 @@ mod tests {
let values = subgroup
.into_iter()
.map(|x| evaluate_at_naive(&coefficients, x))
.map(|x| evaluate_at_naive(coefficients, x))
.collect();
PolynomialValues::new(values)
}
@ -272,8 +275,8 @@ mod tests {
let mut sum = F::ZERO;
let mut point_power = F::ONE;
for &c in &coefficients.coeffs {
sum = sum + c * point_power;
point_power = point_power * point;
sum += c * point_power;
point_power *= point;
}
sum
}

View File

@ -13,12 +13,15 @@ macro_rules! test_field_arithmetic {
#[test]
fn batch_inversion() {
let xs = (1..=3)
.map(|i| <$field>::from_canonical_u64(i))
.collect::<Vec<_>>();
let invs = <$field>::batch_multiplicative_inverse(&xs);
for (x, inv) in xs.into_iter().zip(invs) {
assert_eq!(x * inv, <$field>::ONE);
for n in 0..20 {
let xs = (1..=n as u64)
.map(|i| <$field>::from_canonical_u64(i))
.collect::<Vec<_>>();
let invs = <$field>::batch_multiplicative_inverse(&xs);
assert_eq!(invs.len(), n);
for (x, inv) in xs.into_iter().zip(invs) {
assert_eq!(x * inv, <$field>::ONE);
}
}
}
@ -81,10 +84,24 @@ macro_rules! test_field_arithmetic {
assert_eq!(base.exp_biguint(&pow), base.exp_biguint(&big_pow));
assert_ne!(base.exp_biguint(&pow), base.exp_biguint(&big_pow_wrong));
}
#[test]
fn inverses() {
type F = $field;
let x = F::rand();
let x1 = x.inverse();
let x2 = x1.inverse();
let x3 = x2.inverse();
assert_eq!(x, x2);
assert_eq!(x1, x3);
}
}
};
}
#[allow(clippy::eq_op)]
pub(crate) fn test_add_neg_sub_mul<BF: Extendable<D>, const D: usize>() {
let x = BF::Extension::rand();
let y = BF::Extension::rand();

View File

@ -1,11 +1,10 @@
use std::convert::TryInto;
use std::fmt::{Debug, Display};
use std::hash::Hash;
use std::iter::{Product, Sum};
use std::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssign};
use num::bigint::BigUint;
use num::{Integer, One, Zero};
use num::{Integer, One, ToPrimitive, Zero};
use rand::Rng;
use serde::de::DeserializeOwned;
use serde::Serialize;
@ -43,24 +42,28 @@ pub trait Field:
+ Serialize
+ DeserializeOwned
{
type PrimeField: PrimeField;
const ZERO: Self;
const ONE: Self;
const TWO: Self;
const NEG_ONE: Self;
const CHARACTERISTIC: u64;
/// The 2-adicity of this field's multiplicative group.
const TWO_ADICITY: usize;
/// The 2-adicity of `characteristic() - 1`, i.e. the largest `k` such that `2^k` divides
/// `characteristic() - 1`.
const CHARACTERISTIC_TWO_ADICITY: usize;
/// Generator of the entire multiplicative group, i.e. all non-zero elements.
const MULTIPLICATIVE_GROUP_GENERATOR: Self;
/// Generator of a multiplicative subgroup of order `2^TWO_ADICITY`.
const POWER_OF_TWO_GENERATOR: Self;
/// The bit length of the field order.
const BITS: usize;
fn order() -> BigUint;
fn characteristic() -> BigUint;
#[inline]
fn is_zero(&self) -> bool {
@ -92,6 +95,10 @@ pub trait Field:
self.square() * *self
}
/// Returns three times this element, computed as `self * (ONE + TWO)`.
fn triple(&self) -> Self {
    let three = Self::ONE + Self::TWO;
    *self * three
}
/// Compute the multiplicative inverse of this field element.
fn try_inverse(&self) -> Option<Self>;
@ -103,34 +110,91 @@ pub trait Field:
// This is Montgomery's trick. At a high level, we invert the product of the given field
// elements, then derive the individual inverses from that via multiplication.
// The usual Montgomery trick involves calculating an array of cumulative products,
// resulting in a long dependency chain. To increase instruction-level parallelism, we
// compute WIDTH separate cumulative product arrays that only meet at the end.
// Higher WIDTH increases instruction-level parallelism, but too high a value will cause us
// to run out of registers.
const WIDTH: usize = 4;
// JN note: WIDTH is 4. The code is specialized to this value and will need
// modification if it is changed. I tried to make it more generic, but Rust's const
// generics are not yet good enough.
// Handle special cases. Paradoxically, below is repetitive but concise.
// The branches should be very predictable.
let n = x.len();
if n == 0 {
return Vec::new();
}
if n == 1 {
} else if n == 1 {
return vec![x[0].inverse()];
} else if n == 2 {
let x01 = x[0] * x[1];
let x01inv = x01.inverse();
return vec![x01inv * x[1], x01inv * x[0]];
} else if n == 3 {
let x01 = x[0] * x[1];
let x012 = x01 * x[2];
let x012inv = x012.inverse();
let x01inv = x012inv * x[2];
return vec![x01inv * x[1], x01inv * x[0], x012inv * x01];
}
debug_assert!(n >= WIDTH);
// Fill buf with cumulative product of x.
let mut buf = Vec::with_capacity(n);
let mut cumul_prod = x[0];
buf.push(cumul_prod);
for i in 1..n {
cumul_prod *= x[i];
buf.push(cumul_prod);
// Buf is reused for a few things to save allocations.
// Fill buf with cumulative product of x, only taking every 4th value. Concretely, buf will
// be [
// x[0], x[1], x[2], x[3],
// x[0] * x[4], x[1] * x[5], x[2] * x[6], x[3] * x[7],
// x[0] * x[4] * x[8], x[1] * x[5] * x[9], x[2] * x[6] * x[10], x[3] * x[7] * x[11],
// ...
// ].
// If n is not a multiple of WIDTH, the result is truncated from the end. For example,
// for n == 5, we get [x[0], x[1], x[2], x[3], x[0] * x[4]].
let mut buf: Vec<Self> = Vec::with_capacity(n);
// cumul_prod holds the last WIDTH elements of buf. This is redundant, but it's how we
// convince LLVM to keep the values in the registers.
let mut cumul_prod: [Self; WIDTH] = x[..WIDTH].try_into().unwrap();
buf.extend(cumul_prod);
for (i, &xi) in x[WIDTH..].iter().enumerate() {
cumul_prod[i % WIDTH] *= xi;
buf.push(cumul_prod[i % WIDTH]);
}
debug_assert_eq!(buf.len(), n);
// At this stage buf contains the cumulative product of x. We reuse the buffer for
// efficiency. At the end of the loop, it is filled with inverses of x.
let mut a_inv = cumul_prod.inverse();
buf[n - 1] = buf[n - 2] * a_inv;
for i in (1..n - 1).rev() {
a_inv = x[i + 1] * a_inv;
// buf[i - 1] has not been written to by this loop, so it equals x[0] * ... x[n - 1].
buf[i] = buf[i - 1] * a_inv;
let mut a_inv = {
// This is where the four dependency chains meet.
// Take the last four elements of buf and invert them all.
let c01 = cumul_prod[0] * cumul_prod[1];
let c23 = cumul_prod[2] * cumul_prod[3];
let c0123 = c01 * c23;
let c0123inv = c0123.inverse();
let c01inv = c0123inv * c23;
let c23inv = c0123inv * c01;
[
c01inv * cumul_prod[1],
c01inv * cumul_prod[0],
c23inv * cumul_prod[3],
c23inv * cumul_prod[2],
]
};
for i in (WIDTH..n).rev() {
// buf[i - WIDTH] has not been written to by this loop, so it equals
// x[i % WIDTH] * x[i % WIDTH + WIDTH] * ... * x[i - WIDTH].
buf[i] = buf[i - WIDTH] * a_inv[i % WIDTH];
// buf[i] now holds the inverse of x[i].
a_inv[i % WIDTH] *= x[i];
}
buf[0] = x[1] * a_inv;
for i in (0..WIDTH).rev() {
buf[i] = a_inv[i];
}
for (&bi, &xi) in buf.iter().zip(x) {
// Sanity check only.
debug_assert_eq!(bi * xi, Self::ONE);
}
buf
}
@ -142,29 +206,31 @@ pub trait Field:
// exp exceeds t, we repeatedly multiply by 2^-t and reduce
// exp until it's in the right range.
let p = Self::CHARACTERISTIC;
if let Some(p) = Self::characteristic().to_u64() {
// NB: The only reason this is split into two cases is to save
// the multiplication (and possible calculation of
// inverse_2_pow_adicity) in the usual case that exp <=
// TWO_ADICITY. Can remove the branch and simplify if that
// saving isn't worth it.
// NB: The only reason this is split into two cases is to save
// the multiplication (and possible calculation of
// inverse_2_pow_adicity) in the usual case that exp <=
// TWO_ADICITY. Can remove the branch and simplify if that
// saving isn't worth it.
if exp > Self::CHARACTERISTIC_TWO_ADICITY {
// NB: This should be a compile-time constant
let inverse_2_pow_adicity: Self =
Self::from_canonical_u64(p - ((p - 1) >> Self::CHARACTERISTIC_TWO_ADICITY));
if exp > Self::PrimeField::TWO_ADICITY {
// NB: This should be a compile-time constant
let inverse_2_pow_adicity: Self =
Self::from_canonical_u64(p - ((p - 1) >> Self::PrimeField::TWO_ADICITY));
let mut res = inverse_2_pow_adicity;
let mut e = exp - Self::CHARACTERISTIC_TWO_ADICITY;
let mut res = inverse_2_pow_adicity;
let mut e = exp - Self::PrimeField::TWO_ADICITY;
while e > Self::PrimeField::TWO_ADICITY {
res *= inverse_2_pow_adicity;
e -= Self::PrimeField::TWO_ADICITY;
while e > Self::CHARACTERISTIC_TWO_ADICITY {
res *= inverse_2_pow_adicity;
e -= Self::CHARACTERISTIC_TWO_ADICITY;
}
res * Self::from_canonical_u64(p - ((p - 1) >> e))
} else {
Self::from_canonical_u64(p - ((p - 1) >> exp))
}
res * Self::from_canonical_u64(p - ((p - 1) >> e))
} else {
Self::from_canonical_u64(p - ((p - 1) >> exp))
Self::TWO.inverse().exp_u64(exp as u64)
}
}
@ -206,6 +272,11 @@ pub trait Field:
subgroup.into_iter().map(|x| x * shift).collect()
}
// TODO: move these to a new `PrimeField` trait (for all prime fields, not just 64-bit ones)
fn from_biguint(n: BigUint) -> Self;
fn to_biguint(&self) -> BigUint;
fn from_canonical_u64(n: u64) -> Self;
fn from_canonical_u32(n: u32) -> Self {
@ -274,7 +345,7 @@ pub trait Field:
}
fn kth_root_u64(&self, k: u64) -> Self {
let p = Self::order().clone();
let p = Self::order();
let p_minus_1 = &p - 1u32;
debug_assert!(
Self::is_monomial_permutation_u64(k),
@ -356,6 +427,7 @@ pub trait PrimeField: Field {
unsafe { self.sub_canonical_u64(1) }
}
/// # Safety
/// Equivalent to *self + Self::from_canonical_u64(rhs), but may be cheaper. The caller must
/// ensure that 0 <= rhs < Self::ORDER. The function may return incorrect results if this
/// precondition is not met. It is marked unsafe for this reason.
@ -365,6 +437,7 @@ pub trait PrimeField: Field {
*self + Self::from_canonical_u64(rhs)
}
/// # Safety
/// Equivalent to *self - Self::from_canonical_u64(rhs), but may be cheaper. The caller must
/// ensure that 0 <= rhs < Self::ORDER. The function may return incorrect results if this
/// precondition is not met. It is marked unsafe for this reason.

View File

@ -4,7 +4,7 @@ use std::hash::{Hash, Hasher};
use std::iter::{Product, Sum};
use std::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssign};
use num::BigUint;
use num::{BigUint, Integer};
use rand::Rng;
use serde::{Deserialize, Serialize};
@ -62,15 +62,13 @@ impl Debug for GoldilocksField {
}
impl Field for GoldilocksField {
type PrimeField = Self;
const ZERO: Self = Self(0);
const ONE: Self = Self(1);
const TWO: Self = Self(2);
const NEG_ONE: Self = Self(Self::ORDER - 1);
const CHARACTERISTIC: u64 = Self::ORDER;
const TWO_ADICITY: usize = 32;
const CHARACTERISTIC_TWO_ADICITY: usize = Self::TWO_ADICITY;
// Sage: `g = GF(p).multiplicative_generator()`
const MULTIPLICATIVE_GROUP_GENERATOR: Self = Self(7);
@ -82,15 +80,28 @@ impl Field for GoldilocksField {
// ```
const POWER_OF_TWO_GENERATOR: Self = Self(1753635133440165772);
const BITS: usize = 64;
fn order() -> BigUint {
Self::ORDER.into()
}
fn characteristic() -> BigUint {
Self::order()
}
#[inline(always)]
fn try_inverse(&self) -> Option<Self> {
try_inverse_u64(self)
}
/// Reduces `n` modulo the field order and wraps the result.
///
/// The reduced value is below `Self::ORDER`, so it has at most one 64-bit digit.
fn from_biguint(n: BigUint) -> Self {
    let reduced = n.mod_floor(&Self::order());
    // `to_u64_digits` yields no digits at all for zero, so indexing `[0]` would panic for 0 and
    // for every multiple of the order. Treat the missing digit as 0 instead.
    Self(reduced.to_u64_digits().first().copied().unwrap_or(0))
}
/// Converts this element to a `BigUint` via its canonical `u64` representation.
fn to_biguint(&self) -> BigUint {
self.to_canonical_u64().into()
}
#[inline]
fn from_canonical_u64(n: u64) -> Self {
debug_assert!(n < Self::ORDER);
@ -312,6 +323,7 @@ impl RichField for GoldilocksField {}
#[inline(always)]
#[cfg(target_arch = "x86_64")]
unsafe fn add_no_canonicalize_trashing_input(x: u64, y: u64) -> u64 {
use std::arch::asm;
let res_wrapped: u64;
let adjustment: u64;
asm!(
@ -352,6 +364,7 @@ unsafe fn add_no_canonicalize_trashing_input(x: u64, y: u64) -> u64 {
#[inline(always)]
#[cfg(target_arch = "x86_64")]
unsafe fn sub_no_canonicalize_trashing_input(x: u64, y: u64) -> u64 {
use std::arch::asm;
let res_wrapped: u64;
let adjustment: u64;
asm!(

View File

@ -1,6 +1,6 @@
use crate::field::fft::ifft;
use crate::field::field_types::Field;
use crate::polynomial::polynomial::{PolynomialCoeffs, PolynomialValues};
use crate::polynomial::{PolynomialCoeffs, PolynomialValues};
use crate::util::log2_ceil;
/// Computes the unique degree < n interpolant of an arbitrary list of n (point, value) pairs.
@ -80,7 +80,7 @@ mod tests {
use crate::field::extension_field::quartic::QuarticExtension;
use crate::field::field_types::Field;
use crate::field::goldilocks_field::GoldilocksField;
use crate::polynomial::polynomial::PolynomialCoeffs;
use crate::polynomial::PolynomialCoeffs;
#[test]
fn interpolant_random() {

View File

@ -7,7 +7,8 @@ pub(crate) mod interpolation;
mod inversion;
pub(crate) mod packable;
pub(crate) mod packed_field;
pub mod secp256k1;
pub mod secp256k1_base;
pub mod secp256k1_scalar;
#[cfg(target_feature = "avx2")]
pub(crate) mod packed_avx2;

View File

@ -1,18 +1,18 @@
use crate::field::field_types::Field;
use crate::field::packed_field::{PackedField, Singleton};
use crate::field::packed_field::PackedField;
/// Points us to the default packing for a particular field. There may be multiple choices of
/// PackedField for a particular Field (e.g. Singleton works for all fields), but this is the
/// PackedField for a particular Field (e.g. every Field is also a PackedField), but this is the
/// recommended one. The recommended packing varies by target_arch and target_feature.
pub trait Packable: Field {
type PackedType: PackedField<FieldType = Self>;
type Packing: PackedField<Scalar = Self>;
}
impl<F: Field> Packable for F {
default type PackedType = Singleton<Self>;
default type Packing = Self;
}
#[cfg(target_feature = "avx2")]
impl Packable for crate::field::goldilocks_field::GoldilocksField {
type PackedType = crate::field::packed_avx2::PackedGoldilocksAVX2;
type Packing = crate::field::packed_avx2::PackedGoldilocksAvx2;
}

View File

@ -2,20 +2,20 @@ use core::arch::x86_64::*;
use std::fmt;
use std::fmt::{Debug, Formatter};
use std::iter::{Product, Sum};
use std::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign};
use std::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssign};
use crate::field::field_types::PrimeField;
use crate::field::packed_avx2::common::{
add_no_canonicalize_64_64s_s, epsilon, field_order, ReducibleAVX2,
add_no_canonicalize_64_64s_s, epsilon, field_order, shift, ReducibleAvx2,
};
use crate::field::packed_field::PackedField;
// PackedPrimeField wraps an array of four u64s, with the new and get methods to convert that
// Avx2PrimeField wraps an array of four u64s, with the new and get methods to convert that
// array to and from __m256i, which is the type we actually operate on. This indirection is a
// terrible trick to change PackedPrimeField's alignment.
// We'd like to be able to cast slices of PrimeField to slices of PackedPrimeField. Rust
// terrible trick to change Avx2PrimeField's alignment.
// We'd like to be able to cast slices of PrimeField to slices of Avx2PrimeField. Rust
// aligns __m256i to 32 bytes but PrimeField has a lower alignment. That alignment extends to
// PackedPrimeField and it appears that it cannot be lowered with #[repr(C, blah)]. It is
// Avx2PrimeField and it appears that it cannot be lowered with #[repr(C, blah)]. It is
// important for Rust not to assume 32-byte alignment, so we cannot wrap __m256i directly.
// There are two versions of vectorized load/store instructions on x86: aligned (vmovaps and
// friends) and unaligned (vmovups etc.). The difference between them is that aligned loads and
@ -23,12 +23,12 @@ use crate::field::packed_field::PackedField;
// were faster, and although this is no longer the case, compilers prefer the aligned versions if
// they know that the address is aligned. Using aligned instructions on unaligned addresses leads to
// bugs that can be frustrating to diagnose. Hence, we can't have Rust assuming alignment, and
// therefore PackedPrimeField wraps [F; 4] and not __m256i.
// therefore Avx2PrimeField wraps [F; 4] and not __m256i.
#[derive(Copy, Clone)]
#[repr(transparent)]
pub struct PackedPrimeField<F: ReducibleAVX2>(pub [F; 4]);
pub struct Avx2PrimeField<F: ReducibleAvx2>(pub [F; 4]);
impl<F: ReducibleAVX2> PackedPrimeField<F> {
impl<F: ReducibleAvx2> Avx2PrimeField<F> {
#[inline]
fn new(x: __m256i) -> Self {
let mut obj = Self([F::ZERO; 4]);
@ -43,84 +43,111 @@ impl<F: ReducibleAVX2> PackedPrimeField<F> {
let ptr = (&self.0).as_ptr().cast::<__m256i>();
unsafe { _mm256_loadu_si256(ptr) }
}
/// Addition that assumes x + y < 2^64 + F::ORDER. May return incorrect results if this
/// condition is not met, hence it is marked unsafe.
#[inline]
pub unsafe fn add_canonical_u64(&self, rhs: __m256i) -> Self {
Self::new(add_canonical_u64::<F>(self.get(), rhs))
}
}
impl<F: ReducibleAVX2> Add<Self> for PackedPrimeField<F> {
impl<F: ReducibleAvx2> Add<Self> for Avx2PrimeField<F> {
type Output = Self;
#[inline]
fn add(self, rhs: Self) -> Self {
Self::new(unsafe { add::<F>(self.get(), rhs.get()) })
}
}
impl<F: ReducibleAVX2> Add<F> for PackedPrimeField<F> {
impl<F: ReducibleAvx2> Add<F> for Avx2PrimeField<F> {
type Output = Self;
#[inline]
fn add(self, rhs: F) -> Self {
self + Self::broadcast(rhs)
self + Self::from(rhs)
}
}
impl<F: ReducibleAVX2> AddAssign<Self> for PackedPrimeField<F> {
/// Scalar-on-the-left addition: lets a scalar be added to a packed value with the scalar as
/// the left operand.
impl<F: ReducibleAvx2> Add<Avx2PrimeField<F>> for <Avx2PrimeField<F> as PackedField>::Scalar {
type Output = Avx2PrimeField<F>;
#[inline]
fn add(self, rhs: Self::Output) -> Self::Output {
// Convert the scalar to the packed type via `From`, then use packed + packed addition.
Self::Output::from(self) + rhs
}
}
impl<F: ReducibleAvx2> AddAssign<Self> for Avx2PrimeField<F> {
#[inline]
fn add_assign(&mut self, rhs: Self) {
*self = *self + rhs;
}
}
impl<F: ReducibleAVX2> AddAssign<F> for PackedPrimeField<F> {
impl<F: ReducibleAvx2> AddAssign<F> for Avx2PrimeField<F> {
#[inline]
fn add_assign(&mut self, rhs: F) {
*self = *self + rhs;
}
}
impl<F: ReducibleAVX2> Debug for PackedPrimeField<F> {
impl<F: ReducibleAvx2> Debug for Avx2PrimeField<F> {
#[inline]
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
write!(f, "({:?})", self.get())
}
}
impl<F: ReducibleAVX2> Default for PackedPrimeField<F> {
impl<F: ReducibleAvx2> Default for Avx2PrimeField<F> {
#[inline]
fn default() -> Self {
Self::zero()
Self::ZERO
}
}
impl<F: ReducibleAVX2> Mul<Self> for PackedPrimeField<F> {
/// Division of a packed value by a single scalar, implemented as multiplication by the
/// scalar's inverse.
// NOTE(review): behaviour for `rhs == 0` depends on `F::inverse`, which is not visible here.
impl<F: ReducibleAvx2> Div<F> for Avx2PrimeField<F> {
type Output = Self;
#[inline]
fn div(self, rhs: F) -> Self {
// One scalar inversion, then a packed-by-scalar multiplication.
self * rhs.inverse()
}
}
/// In-place division of a packed value by a single scalar, implemented as in-place
/// multiplication by the scalar's inverse.
impl<F: ReducibleAvx2> DivAssign<F> for Avx2PrimeField<F> {
#[inline]
fn div_assign(&mut self, rhs: F) {
// One scalar inversion, then `MulAssign<F>`.
*self *= rhs.inverse();
}
}
/// Broadcasts a scalar: every one of the four array entries is set to `x`.
impl<F: ReducibleAvx2> From<F> for Avx2PrimeField<F> {
fn from(x: F) -> Self {
Self([x; 4])
}
}
impl<F: ReducibleAvx2> Mul<Self> for Avx2PrimeField<F> {
type Output = Self;
#[inline]
fn mul(self, rhs: Self) -> Self {
Self::new(unsafe { mul::<F>(self.get(), rhs.get()) })
}
}
impl<F: ReducibleAVX2> Mul<F> for PackedPrimeField<F> {
impl<F: ReducibleAvx2> Mul<F> for Avx2PrimeField<F> {
type Output = Self;
#[inline]
fn mul(self, rhs: F) -> Self {
self * Self::broadcast(rhs)
self * Self::from(rhs)
}
}
impl<F: ReducibleAVX2> MulAssign<Self> for PackedPrimeField<F> {
/// Scalar-on-the-left multiplication: lets a scalar be multiplied with a packed value with the
/// scalar as the left operand.
impl<F: ReducibleAvx2> Mul<Avx2PrimeField<F>> for <Avx2PrimeField<F> as PackedField>::Scalar {
type Output = Avx2PrimeField<F>;
#[inline]
fn mul(self, rhs: Avx2PrimeField<F>) -> Self::Output {
// Convert the scalar to the packed type via `From`, then use packed * packed multiplication.
Self::Output::from(self) * rhs
}
}
impl<F: ReducibleAvx2> MulAssign<Self> for Avx2PrimeField<F> {
#[inline]
fn mul_assign(&mut self, rhs: Self) {
*self = *self * rhs;
}
}
impl<F: ReducibleAVX2> MulAssign<F> for PackedPrimeField<F> {
impl<F: ReducibleAvx2> MulAssign<F> for Avx2PrimeField<F> {
#[inline]
fn mul_assign(&mut self, rhs: F) {
*self = *self * rhs;
}
}
impl<F: ReducibleAVX2> Neg for PackedPrimeField<F> {
impl<F: ReducibleAvx2> Neg for Avx2PrimeField<F> {
type Output = Self;
#[inline]
fn neg(self) -> Self {
@ -128,52 +155,59 @@ impl<F: ReducibleAVX2> Neg for PackedPrimeField<F> {
}
}
impl<F: ReducibleAVX2> Product for PackedPrimeField<F> {
impl<F: ReducibleAvx2> Product for Avx2PrimeField<F> {
#[inline]
fn product<I: Iterator<Item = Self>>(iter: I) -> Self {
iter.reduce(|x, y| x * y).unwrap_or(Self::one())
iter.reduce(|x, y| x * y).unwrap_or(Self::ONE)
}
}
impl<F: ReducibleAVX2> PackedField for PackedPrimeField<F> {
const LOG2_WIDTH: usize = 2;
unsafe impl<F: ReducibleAvx2> PackedField for Avx2PrimeField<F> {
const WIDTH: usize = 4;
type FieldType = F;
type Scalar = F;
type PackedPrimeField = Avx2PrimeField<F>;
const ZERO: Self = Self([F::ZERO; 4]);
const ONE: Self = Self([F::ONE; 4]);
#[inline]
fn broadcast(x: F) -> Self {
Self([x; 4])
}
#[inline]
fn from_arr(arr: [F; Self::WIDTH]) -> Self {
fn from_arr(arr: [Self::Scalar; Self::WIDTH]) -> Self {
Self(arr)
}
#[inline]
fn to_arr(&self) -> [F; Self::WIDTH] {
fn as_arr(&self) -> [Self::Scalar; Self::WIDTH] {
self.0
}
#[inline]
fn from_slice(slice: &[F]) -> Self {
assert!(slice.len() == 4);
Self([slice[0], slice[1], slice[2], slice[3]])
fn from_slice(slice: &[Self::Scalar]) -> &Self {
assert_eq!(slice.len(), Self::WIDTH);
unsafe { &*slice.as_ptr().cast() }
}
#[inline]
fn from_slice_mut(slice: &mut [Self::Scalar]) -> &mut Self {
assert_eq!(slice.len(), Self::WIDTH);
unsafe { &mut *slice.as_mut_ptr().cast() }
}
#[inline]
fn as_slice(&self) -> &[Self::Scalar] {
&self.0[..]
}
#[inline]
fn as_slice_mut(&mut self) -> &mut [Self::Scalar] {
&mut self.0[..]
}
#[inline]
fn to_vec(&self) -> Vec<F> {
self.0.into()
}
#[inline]
fn interleave(&self, other: Self, r: usize) -> (Self, Self) {
fn interleave(&self, other: Self, block_len: usize) -> (Self, Self) {
let (v0, v1) = (self.get(), other.get());
let (res0, res1) = match r {
0 => unsafe { interleave0(v0, v1) },
let (res0, res1) = match block_len {
1 => unsafe { interleave1(v0, v1) },
2 => (v0, v1),
_ => panic!("r cannot be more than LOG2_WIDTH"),
2 => unsafe { interleave2(v0, v1) },
4 => (v0, v1),
_ => panic!("unsupported block_len"),
};
(Self::new(res0), Self::new(res1))
}
@ -184,47 +218,47 @@ impl<F: ReducibleAVX2> PackedField for PackedPrimeField<F> {
}
}
impl<F: ReducibleAVX2> Sub<Self> for PackedPrimeField<F> {
/// Packed-by-packed subtraction, delegated to this file's `sub` routine.
impl<F: ReducibleAvx2> Sub<Self> for Avx2PrimeField<F> {
    type Output = Self;
    #[inline]
    fn sub(self, rhs: Self) -> Self {
        let (lhs_vec, rhs_vec) = (self.get(), rhs.get());
        // NOTE(review): presumably sound because this module is only built when AVX2 is
        // available — confirm against the module's cfg gate.
        let difference = unsafe { sub::<F>(lhs_vec, rhs_vec) };
        Self::new(difference)
    }
}
impl<F: ReducibleAVX2> Sub<F> for PackedPrimeField<F> {
impl<F: ReducibleAvx2> Sub<F> for Avx2PrimeField<F> {
type Output = Self;
#[inline]
fn sub(self, rhs: F) -> Self {
self - Self::broadcast(rhs)
self - Self::from(rhs)
}
}
impl<F: ReducibleAVX2> SubAssign<Self> for PackedPrimeField<F> {
impl<F: ReducibleAvx2> Sub<Avx2PrimeField<F>> for <Avx2PrimeField<F> as PackedField>::Scalar {
type Output = Avx2PrimeField<F>;
#[inline]
fn sub(self, rhs: Avx2PrimeField<F>) -> Self::Output {
Self::Output::from(self) - rhs
}
}
/// In-place packed-by-packed subtraction, expressed through the by-value `Sub` impl.
impl<F: ReducibleAvx2> SubAssign<Self> for Avx2PrimeField<F> {
    #[inline]
    fn sub_assign(&mut self, rhs: Self) {
        let difference = *self - rhs;
        *self = difference;
    }
}
impl<F: ReducibleAVX2> SubAssign<F> for PackedPrimeField<F> {
/// In-place packed-by-scalar subtraction, expressed through the by-value `Sub<F>` impl.
impl<F: ReducibleAvx2> SubAssign<F> for Avx2PrimeField<F> {
    #[inline]
    fn sub_assign(&mut self, rhs: F) {
        let difference = *self - rhs;
        *self = difference;
    }
}
impl<F: ReducibleAVX2> Sum for PackedPrimeField<F> {
impl<F: ReducibleAvx2> Sum for Avx2PrimeField<F> {
#[inline]
fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
iter.reduce(|x, y| x + y).unwrap_or(Self::zero())
iter.reduce(|x, y| x + y).unwrap_or(Self::ZERO)
}
}
const SIGN_BIT: u64 = 1 << 63;
#[inline]
unsafe fn sign_bit() -> __m256i {
_mm256_set1_epi64x(SIGN_BIT as i64)
}
// Resources:
// 1. Intel Intrinsics Guide for explanation of each intrinsic:
// https://software.intel.com/sites/landingpage/IntrinsicsGuide/
@ -274,12 +308,6 @@ unsafe fn sign_bit() -> __m256i {
// Notice that the above 3-value addition still only requires two calls to shift, just like our
// 2-value addition.
/// Add 2^63 with overflow. Needed to emulate unsigned comparisons (see point 3. above).
#[inline]
unsafe fn shift(x: __m256i) -> __m256i {
_mm256_xor_si256(x, sign_bit())
}
/// Convert to canonical representation.
/// The argument is assumed to be shifted by 1 << 63 (i.e. x_s = x + 1<<63, where x is the field
/// value). The returned value is similarly shifted by 1 << 63 (i.e. we return y_s = y + (1<<63),
@ -293,14 +321,6 @@ unsafe fn canonicalize_s<F: PrimeField>(x_s: __m256i) -> __m256i {
_mm256_add_epi64(x_s, wrapback_amt)
}
/// Addition that assumes x + y < 2^64 + F::ORDER.
#[inline]
unsafe fn add_canonical_u64<F: PrimeField>(x: __m256i, y: __m256i) -> __m256i {
let y_s = shift(y);
let res_s = add_no_canonicalize_64_64s_s::<F>(x, y_s);
shift(res_s)
}
#[inline]
unsafe fn add<F: PrimeField>(x: __m256i, y: __m256i) -> __m256i {
let y_s = shift(y);
@ -326,78 +346,94 @@ unsafe fn neg<F: PrimeField>(y: __m256i) -> __m256i {
_mm256_sub_epi64(shift(field_order::<F>()), canonicalize_s::<F>(y_s))
}
/// Full 64-bit by 64-bit multiplication. This emulated multiplication is 1.5x slower than the
/// Full 64-bit by 64-bit multiplication. This emulated multiplication is 1.33x slower than the
/// scalar instruction, but may be worth it if we want our data to live in vector registers.
#[inline]
unsafe fn mul64_64_s(x: __m256i, y: __m256i) -> (__m256i, __m256i) {
let x_hi = _mm256_srli_epi64(x, 32);
let y_hi = _mm256_srli_epi64(y, 32);
unsafe fn mul64_64(x: __m256i, y: __m256i) -> (__m256i, __m256i) {
// We want to move the high 32 bits to the low position. The multiplication instruction ignores
// the high 32 bits, so it's ok to just duplicate it into the low position. This duplication can
// be done on port 5; bitshifts run on ports 0 and 1, competing with multiplication.
// This instruction is only provided for 32-bit floats, not integers. Idk why Intel makes the
// distinction; the casts are free and it guarantees that the exact bit pattern is preserved.
// Using a swizzle instruction of the wrong domain (float vs int) does not increase latency
// since Haswell.
let x_hi = _mm256_castps_si256(_mm256_movehdup_ps(_mm256_castsi256_ps(x)));
let y_hi = _mm256_castps_si256(_mm256_movehdup_ps(_mm256_castsi256_ps(y)));
// All four pairwise multiplications
let mul_ll = _mm256_mul_epu32(x, y);
let mul_lh = _mm256_mul_epu32(x, y_hi);
let mul_hl = _mm256_mul_epu32(x_hi, y);
let mul_hh = _mm256_mul_epu32(x_hi, y_hi);
let res_lo0_s = shift(mul_ll);
let res_lo1_s = _mm256_add_epi32(res_lo0_s, _mm256_slli_epi64(mul_lh, 32));
let res_lo2_s = _mm256_add_epi32(res_lo1_s, _mm256_slli_epi64(mul_hl, 32));
// Bignum addition
// Extract high 32 bits of mul_ll and add to mul_hl. This cannot overflow.
let mul_ll_hi = _mm256_srli_epi64::<32>(mul_ll);
let t0 = _mm256_add_epi64(mul_hl, mul_ll_hi);
// Extract low 32 bits of t0 and add to mul_lh. Again, this cannot overflow.
// Also, extract high 32 bits of t0 and add to mul_hh.
let t0_lo = _mm256_and_si256(t0, _mm256_set1_epi64x(u32::MAX.into()));
let t0_hi = _mm256_srli_epi64::<32>(t0);
let t1 = _mm256_add_epi64(mul_lh, t0_lo);
let t2 = _mm256_add_epi64(mul_hh, t0_hi);
// Lastly, extract the high 32 bits of t1 and add to t2.
let t1_hi = _mm256_srli_epi64::<32>(t1);
let res_hi = _mm256_add_epi64(t2, t1_hi);
// cmpgt returns -1 on true and 0 on false. Hence, the carry values below are set to -1 on
// overflow and must be subtracted, not added.
let carry0 = _mm256_cmpgt_epi64(res_lo0_s, res_lo1_s);
let carry1 = _mm256_cmpgt_epi64(res_lo1_s, res_lo2_s);
// Form res_lo by combining the low half of mul_ll with the low half of t1 (shifted into high
// position).
let t1_lo = _mm256_castps_si256(_mm256_moveldup_ps(_mm256_castsi256_ps(t1)));
let res_lo = _mm256_blend_epi32::<0xaa>(mul_ll, t1_lo);
let res_hi0 = mul_hh;
let res_hi1 = _mm256_add_epi64(res_hi0, _mm256_srli_epi64(mul_lh, 32));
let res_hi2 = _mm256_add_epi64(res_hi1, _mm256_srli_epi64(mul_hl, 32));
let res_hi3 = _mm256_sub_epi64(res_hi2, carry0);
let res_hi4 = _mm256_sub_epi64(res_hi3, carry1);
(res_hi4, res_lo2_s)
(res_hi, res_lo)
}
/// Full 64-bit squaring. This routine is 1.2x faster than the scalar instruction.
#[inline]
unsafe fn square64_s(x: __m256i) -> (__m256i, __m256i) {
let x_hi = _mm256_srli_epi64(x, 32);
unsafe fn square64(x: __m256i) -> (__m256i, __m256i) {
// Get high 32 bits of x. See comment in mul64_64.
let x_hi = _mm256_castps_si256(_mm256_movehdup_ps(_mm256_castsi256_ps(x)));
// All pairwise multiplications.
let mul_ll = _mm256_mul_epu32(x, x);
let mul_lh = _mm256_mul_epu32(x, x_hi);
let mul_hh = _mm256_mul_epu32(x_hi, x_hi);
let res_lo0_s = shift(mul_ll);
let res_lo1_s = _mm256_add_epi32(res_lo0_s, _mm256_slli_epi64(mul_lh, 33));
// Bignum addition, but mul_lh is shifted by 33 bits (not 32).
let mul_ll_hi = _mm256_srli_epi64::<33>(mul_ll);
let t0 = _mm256_add_epi64(mul_lh, mul_ll_hi);
let t0_hi = _mm256_srli_epi64::<31>(t0);
let res_hi = _mm256_add_epi64(mul_hh, t0_hi);
// cmpgt returns -1 on true and 0 on false. Hence, the carry values below are set to -1 on
// overflow and must be subtracted, not added.
let carry = _mm256_cmpgt_epi64(res_lo0_s, res_lo1_s);
// Form low result by adding the mul_ll and the low 31 bits of mul_lh (shifted to the high
// position).
let mul_lh_lo = _mm256_slli_epi64::<33>(mul_lh);
let res_lo = _mm256_add_epi64(mul_ll, mul_lh_lo);
let res_hi0 = mul_hh;
let res_hi1 = _mm256_add_epi64(res_hi0, _mm256_srli_epi64(mul_lh, 31));
let res_hi2 = _mm256_sub_epi64(res_hi1, carry);
(res_hi2, res_lo1_s)
(res_hi, res_lo)
}
/// Multiply two integers modulo FIELD_ORDER.
#[inline]
unsafe fn mul<F: ReducibleAVX2>(x: __m256i, y: __m256i) -> __m256i {
shift(F::reduce128s_s(mul64_64_s(x, y)))
unsafe fn mul<F: ReducibleAvx2>(x: __m256i, y: __m256i) -> __m256i {
F::reduce128(mul64_64(x, y))
}
/// Square an integer modulo FIELD_ORDER.
#[inline]
unsafe fn square<F: ReducibleAVX2>(x: __m256i) -> __m256i {
shift(F::reduce128s_s(square64_s(x)))
unsafe fn square<F: ReducibleAvx2>(x: __m256i) -> __m256i {
F::reduce128(square64(x))
}
#[inline]
unsafe fn interleave0(x: __m256i, y: __m256i) -> (__m256i, __m256i) {
unsafe fn interleave1(x: __m256i, y: __m256i) -> (__m256i, __m256i) {
    // Interleave with block length 1. Within each 128-bit half, `unpacklo` pairs the even
    // 64-bit lanes of x and y, and `unpackhi` pairs the odd ones:
    //   ([x0, y0, x2, y2], [x1, y1, x3, y3]).
    (_mm256_unpacklo_epi64(x, y), _mm256_unpackhi_epi64(x, y))
}
#[inline]
unsafe fn interleave1(x: __m256i, y: __m256i) -> (__m256i, __m256i) {
unsafe fn interleave2(x: __m256i, y: __m256i) -> (__m256i, __m256i) {
let y_lo = _mm256_castsi256_si128(y); // This has 0 cost.
// 1 places y_lo in the high half of x; 0 would place it in the lower half.

View File

@ -2,8 +2,22 @@ use core::arch::x86_64::*;
use crate::field::field_types::PrimeField;
pub trait ReducibleAVX2: PrimeField {
unsafe fn reduce128s_s(x_s: (__m256i, __m256i)) -> __m256i;
pub trait ReducibleAvx2: PrimeField {
unsafe fn reduce128(x: (__m256i, __m256i)) -> __m256i;
}
/// Only the most significant bit of a 64-bit lane set, i.e. 2^63.
const SIGN_BIT: u64 = 0x8000_0000_0000_0000;
/// Vector with `SIGN_BIT` in each of its four 64-bit lanes.
#[inline]
unsafe fn sign_bit() -> __m256i {
    // The intrinsic takes a signed argument; the cast only reinterprets the bit pattern.
    let lane = SIGN_BIT as i64;
    _mm256_set1_epi64x(lane)
}
/// Add 2^63 to every 64-bit lane with wraparound, implemented as an XOR with the sign bit
/// (so applying it twice restores the input). Needed to emulate unsigned comparisons (see
/// point 3. of the resources comment in avx2_prime_field.rs).
#[inline]
pub unsafe fn shift(x: __m256i) -> __m256i {
    let mask = sign_bit();
    _mm256_xor_si256(x, mask)
}
#[inline]

View File

@ -2,19 +2,21 @@ use core::arch::x86_64::*;
use crate::field::goldilocks_field::GoldilocksField;
use crate::field::packed_avx2::common::{
add_no_canonicalize_64_64s_s, epsilon, sub_no_canonicalize_64s_64_s, ReducibleAVX2,
add_no_canonicalize_64_64s_s, epsilon, shift, sub_no_canonicalize_64s_64_s, ReducibleAvx2,
};
/// Reduce a u128 modulo FIELD_ORDER. The input is a (high, low) pair of u64 halves, neither
/// shifted by 2^63; the result is likewise unshifted (the shifts are applied internally).
impl ReducibleAVX2 for GoldilocksField {
impl ReducibleAvx2 for GoldilocksField {
#[inline]
unsafe fn reduce128s_s(x_s: (__m256i, __m256i)) -> __m256i {
let (hi0, lo0_s) = x_s;
unsafe fn reduce128(x: (__m256i, __m256i)) -> __m256i {
let (hi0, lo0) = x;
let lo0_s = shift(lo0);
let hi_hi0 = _mm256_srli_epi64(hi0, 32);
let lo1_s = sub_no_canonicalize_64s_64_s::<GoldilocksField>(lo0_s, hi_hi0);
let t1 = _mm256_mul_epu32(hi0, epsilon::<GoldilocksField>());
let lo2_s = add_no_canonicalize_64_64s_s::<GoldilocksField>(t1, lo1_s);
lo2_s
let lo2 = shift(lo2_s);
lo2
}
}

View File

@ -1,21 +1,21 @@
mod avx2_prime_field;
mod common;
mod goldilocks;
mod packed_prime_field;
use packed_prime_field::PackedPrimeField;
use avx2_prime_field::Avx2PrimeField;
use crate::field::goldilocks_field::GoldilocksField;
pub type PackedGoldilocksAVX2 = PackedPrimeField<GoldilocksField>;
pub type PackedGoldilocksAvx2 = Avx2PrimeField<GoldilocksField>;
#[cfg(test)]
mod tests {
use crate::field::goldilocks_field::GoldilocksField;
use crate::field::packed_avx2::common::ReducibleAVX2;
use crate::field::packed_avx2::packed_prime_field::PackedPrimeField;
use crate::field::packed_avx2::avx2_prime_field::Avx2PrimeField;
use crate::field::packed_avx2::common::ReducibleAvx2;
use crate::field::packed_field::PackedField;
fn test_vals_a<F: ReducibleAVX2>() -> [F; 4] {
fn test_vals_a<F: ReducibleAvx2>() -> [F; 4] {
[
F::from_noncanonical_u64(14479013849828404771),
F::from_noncanonical_u64(9087029921428221768),
@ -23,7 +23,7 @@ mod tests {
F::from_noncanonical_u64(5646033492608483824),
]
}
fn test_vals_b<F: ReducibleAVX2>() -> [F; 4] {
fn test_vals_b<F: ReducibleAvx2>() -> [F; 4] {
[
F::from_noncanonical_u64(17891926589593242302),
F::from_noncanonical_u64(11009798273260028228),
@ -32,17 +32,17 @@ mod tests {
]
}
fn test_add<F: ReducibleAVX2>()
fn test_add<F: ReducibleAvx2>()
where
[(); PackedPrimeField::<F>::WIDTH]: ,
[(); Avx2PrimeField::<F>::WIDTH]:,
{
let a_arr = test_vals_a::<F>();
let b_arr = test_vals_b::<F>();
let packed_a = PackedPrimeField::<F>::from_arr(a_arr);
let packed_b = PackedPrimeField::<F>::from_arr(b_arr);
let packed_a = Avx2PrimeField::<F>::from_arr(a_arr);
let packed_b = Avx2PrimeField::<F>::from_arr(b_arr);
let packed_res = packed_a + packed_b;
let arr_res = packed_res.to_arr();
let arr_res = packed_res.as_arr();
let expected = a_arr.iter().zip(b_arr).map(|(&a, b)| a + b);
for (exp, res) in expected.zip(arr_res) {
@ -50,17 +50,17 @@ mod tests {
}
}
fn test_mul<F: ReducibleAVX2>()
fn test_mul<F: ReducibleAvx2>()
where
[(); PackedPrimeField::<F>::WIDTH]: ,
[(); Avx2PrimeField::<F>::WIDTH]:,
{
let a_arr = test_vals_a::<F>();
let b_arr = test_vals_b::<F>();
let packed_a = PackedPrimeField::<F>::from_arr(a_arr);
let packed_b = PackedPrimeField::<F>::from_arr(b_arr);
let packed_a = Avx2PrimeField::<F>::from_arr(a_arr);
let packed_b = Avx2PrimeField::<F>::from_arr(b_arr);
let packed_res = packed_a * packed_b;
let arr_res = packed_res.to_arr();
let arr_res = packed_res.as_arr();
let expected = a_arr.iter().zip(b_arr).map(|(&a, b)| a * b);
for (exp, res) in expected.zip(arr_res) {
@ -68,15 +68,15 @@ mod tests {
}
}
fn test_square<F: ReducibleAVX2>()
fn test_square<F: ReducibleAvx2>()
where
[(); PackedPrimeField::<F>::WIDTH]: ,
[(); Avx2PrimeField::<F>::WIDTH]:,
{
let a_arr = test_vals_a::<F>();
let packed_a = PackedPrimeField::<F>::from_arr(a_arr);
let packed_a = Avx2PrimeField::<F>::from_arr(a_arr);
let packed_res = packed_a.square();
let arr_res = packed_res.to_arr();
let arr_res = packed_res.as_arr();
let expected = a_arr.iter().map(|&a| a.square());
for (exp, res) in expected.zip(arr_res) {
@ -84,15 +84,15 @@ mod tests {
}
}
fn test_neg<F: ReducibleAVX2>()
fn test_neg<F: ReducibleAvx2>()
where
[(); PackedPrimeField::<F>::WIDTH]: ,
[(); Avx2PrimeField::<F>::WIDTH]:,
{
let a_arr = test_vals_a::<F>();
let packed_a = PackedPrimeField::<F>::from_arr(a_arr);
let packed_a = Avx2PrimeField::<F>::from_arr(a_arr);
let packed_res = -packed_a;
let arr_res = packed_res.to_arr();
let arr_res = packed_res.as_arr();
let expected = a_arr.iter().map(|&a| -a);
for (exp, res) in expected.zip(arr_res) {
@ -100,17 +100,17 @@ mod tests {
}
}
fn test_sub<F: ReducibleAVX2>()
fn test_sub<F: ReducibleAvx2>()
where
[(); PackedPrimeField::<F>::WIDTH]: ,
[(); Avx2PrimeField::<F>::WIDTH]:,
{
let a_arr = test_vals_a::<F>();
let b_arr = test_vals_b::<F>();
let packed_a = PackedPrimeField::<F>::from_arr(a_arr);
let packed_b = PackedPrimeField::<F>::from_arr(b_arr);
let packed_a = Avx2PrimeField::<F>::from_arr(a_arr);
let packed_b = Avx2PrimeField::<F>::from_arr(b_arr);
let packed_res = packed_a - packed_b;
let arr_res = packed_res.to_arr();
let arr_res = packed_res.as_arr();
let expected = a_arr.iter().zip(b_arr).map(|(&a, b)| a - b);
for (exp, res) in expected.zip(arr_res) {
@ -118,33 +118,39 @@ mod tests {
}
}
fn test_interleave_is_involution<F: ReducibleAVX2>()
fn test_interleave_is_involution<F: ReducibleAvx2>()
where
[(); PackedPrimeField::<F>::WIDTH]: ,
[(); Avx2PrimeField::<F>::WIDTH]:,
{
let a_arr = test_vals_a::<F>();
let b_arr = test_vals_b::<F>();
let packed_a = PackedPrimeField::<F>::from_arr(a_arr);
let packed_b = PackedPrimeField::<F>::from_arr(b_arr);
let packed_a = Avx2PrimeField::<F>::from_arr(a_arr);
let packed_b = Avx2PrimeField::<F>::from_arr(b_arr);
{
// Interleave, then deinterleave.
let (x, y) = packed_a.interleave(packed_b, 0);
let (res_a, res_b) = x.interleave(y, 0);
assert_eq!(res_a.to_arr(), a_arr);
assert_eq!(res_b.to_arr(), b_arr);
}
{
let (x, y) = packed_a.interleave(packed_b, 1);
let (res_a, res_b) = x.interleave(y, 1);
assert_eq!(res_a.to_arr(), a_arr);
assert_eq!(res_b.to_arr(), b_arr);
assert_eq!(res_a.as_arr(), a_arr);
assert_eq!(res_b.as_arr(), b_arr);
}
{
let (x, y) = packed_a.interleave(packed_b, 2);
let (res_a, res_b) = x.interleave(y, 2);
assert_eq!(res_a.as_arr(), a_arr);
assert_eq!(res_b.as_arr(), b_arr);
}
{
let (x, y) = packed_a.interleave(packed_b, 4);
let (res_a, res_b) = x.interleave(y, 4);
assert_eq!(res_a.as_arr(), a_arr);
assert_eq!(res_b.as_arr(), b_arr);
}
}
fn test_interleave<F: ReducibleAVX2>()
fn test_interleave<F: ReducibleAvx2>()
where
[(); PackedPrimeField::<F>::WIDTH]: ,
[(); Avx2PrimeField::<F>::WIDTH]:,
{
let in_a: [F; 4] = [
F::from_noncanonical_u64(00),
@ -158,42 +164,47 @@ mod tests {
F::from_noncanonical_u64(12),
F::from_noncanonical_u64(13),
];
let int0_a: [F; 4] = [
let int1_a: [F; 4] = [
F::from_noncanonical_u64(00),
F::from_noncanonical_u64(10),
F::from_noncanonical_u64(02),
F::from_noncanonical_u64(12),
];
let int0_b: [F; 4] = [
let int1_b: [F; 4] = [
F::from_noncanonical_u64(01),
F::from_noncanonical_u64(11),
F::from_noncanonical_u64(03),
F::from_noncanonical_u64(13),
];
let int1_a: [F; 4] = [
let int2_a: [F; 4] = [
F::from_noncanonical_u64(00),
F::from_noncanonical_u64(01),
F::from_noncanonical_u64(10),
F::from_noncanonical_u64(11),
];
let int1_b: [F; 4] = [
let int2_b: [F; 4] = [
F::from_noncanonical_u64(02),
F::from_noncanonical_u64(03),
F::from_noncanonical_u64(12),
F::from_noncanonical_u64(13),
];
let packed_a = PackedPrimeField::<F>::from_arr(in_a);
let packed_b = PackedPrimeField::<F>::from_arr(in_b);
{
let (x0, y0) = packed_a.interleave(packed_b, 0);
assert_eq!(x0.to_arr(), int0_a);
assert_eq!(y0.to_arr(), int0_b);
}
let packed_a = Avx2PrimeField::<F>::from_arr(in_a);
let packed_b = Avx2PrimeField::<F>::from_arr(in_b);
{
let (x1, y1) = packed_a.interleave(packed_b, 1);
assert_eq!(x1.to_arr(), int1_a);
assert_eq!(y1.to_arr(), int1_b);
assert_eq!(x1.as_arr(), int1_a);
assert_eq!(y1.as_arr(), int1_b);
}
{
let (x2, y2) = packed_a.interleave(packed_b, 2);
assert_eq!(x2.as_arr(), int2_a);
assert_eq!(y2.as_arr(), int2_b);
}
{
let (x4, y4) = packed_a.interleave(packed_b, 4);
assert_eq!(x4.as_arr(), in_a);
assert_eq!(y4.as_arr(), in_b);
}
}

View File

@ -1,77 +1,81 @@
use std::fmt;
use std::fmt::{Debug, Formatter};
use std::fmt::Debug;
use std::iter::{Product, Sum};
use std::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign};
use std::ops::{Add, AddAssign, Div, Mul, MulAssign, Neg, Sub, SubAssign};
use std::slice;
use crate::field::field_types::Field;
pub trait PackedField:
/// # Safety
/// - WIDTH is assumed to be a power of 2.
/// - If P implements PackedField then P must be castable to/from [P::Scalar; P::WIDTH] without UB.
pub unsafe trait PackedField:
'static
+ Add<Self, Output = Self>
+ Add<Self::FieldType, Output = Self>
+ Add<Self::Scalar, Output = Self>
+ AddAssign<Self>
+ AddAssign<Self::FieldType>
+ AddAssign<Self::Scalar>
+ Copy
+ Debug
+ Default
// TODO: Implementing Div sounds like a pain so it's a worry for later.
+ From<Self::Scalar>
// TODO: Implement packed / packed division
+ Div<Self::Scalar, Output = Self>
+ Mul<Self, Output = Self>
+ Mul<Self::FieldType, Output = Self>
+ Mul<Self::Scalar, Output = Self>
+ MulAssign<Self>
+ MulAssign<Self::FieldType>
+ MulAssign<Self::Scalar>
+ Neg<Output = Self>
+ Product
+ Send
+ Sub<Self, Output = Self>
+ Sub<Self::FieldType, Output = Self>
+ Sub<Self::Scalar, Output = Self>
+ SubAssign<Self>
+ SubAssign<Self::FieldType>
+ SubAssign<Self::Scalar>
+ Sum
+ Sync
where
Self::Scalar: Add<Self, Output = Self>,
Self::Scalar: Mul<Self, Output = Self>,
Self::Scalar: Sub<Self, Output = Self>,
{
type FieldType: Field;
type Scalar: Field;
const LOG2_WIDTH: usize;
const WIDTH: usize = 1 << Self::LOG2_WIDTH;
const WIDTH: usize;
const ZERO: Self;
const ONE: Self;
fn square(&self) -> Self {
*self * *self
}
fn zero() -> Self {
Self::broadcast(Self::FieldType::ZERO)
}
fn one() -> Self {
Self::broadcast(Self::FieldType::ONE)
}
fn from_arr(arr: [Self::Scalar; Self::WIDTH]) -> Self;
fn as_arr(&self) -> [Self::Scalar; Self::WIDTH];
fn broadcast(x: Self::FieldType) -> Self;
fn from_slice(slice: &[Self::Scalar]) -> &Self;
fn from_slice_mut(slice: &mut [Self::Scalar]) -> &mut Self;
fn as_slice(&self) -> &[Self::Scalar];
fn as_slice_mut(&mut self) -> &mut [Self::Scalar];
fn from_arr(arr: [Self::FieldType; Self::WIDTH]) -> Self;
fn to_arr(&self) -> [Self::FieldType; Self::WIDTH];
fn from_slice(slice: &[Self::FieldType]) -> Self;
fn to_vec(&self) -> Vec<Self::FieldType>;
/// Take interpret two vectors as chunks of (1 << r) elements. Unpack and interleave those
/// Interpret two vectors as chunks of block_len elements. Unpack and interleave those
/// chunks. This is best seen with an example. If we have:
/// A = [x0, y0, x1, y1],
/// B = [x2, y2, x3, y3],
/// then
/// interleave(A, B, 0) = ([x0, x2, x1, x3], [y0, y2, y1, y3]).
/// interleave(A, B, 1) = ([x0, x2, x1, x3], [y0, y2, y1, y3]).
/// Pairs that were adjacent in the input are at corresponding positions in the output.
/// r lets us set the size of chunks we're interleaving. If we set r = 1, then for
/// block_len sets the size of chunks we're interleaving. If we set block_len = 2, then for
/// A = [x0, x1, y0, y1],
/// B = [x2, x3, y2, y3],
/// we obtain
/// interleave(A, B, r) = ([x0, x1, x2, x3], [y0, y1, y2, y3]).
/// interleave(A, B, block_len) = ([x0, x1, x2, x3], [y0, y1, y2, y3]).
/// We can also think about this as stacking the vectors, dividing them into 2x2 matrices, and
/// transposing those matrices.
/// When r = LOG2_WIDTH, this operation is a no-op. Values of r > LOG2_WIDTH are not
/// permitted.
fn interleave(&self, other: Self, r: usize) -> (Self, Self);
/// When block_len = WIDTH, this operation is a no-op. block_len must divide WIDTH. Since
/// WIDTH is specified to be a power of 2, block_len must also be a power of 2. It cannot be 0
/// and it cannot be > WIDTH.
fn interleave(&self, other: Self, block_len: usize) -> (Self, Self);
fn pack_slice(buf: &[Self::FieldType]) -> &[Self] {
fn pack_slice(buf: &[Self::Scalar]) -> &[Self] {
assert!(
buf.len() % Self::WIDTH == 0,
"Slice length (got {}) must be a multiple of packed field width ({}).",
@ -82,7 +86,7 @@ pub trait PackedField:
let n = buf.len() / Self::WIDTH;
unsafe { std::slice::from_raw_parts(buf_ptr, n) }
}
fn pack_slice_mut(buf: &mut [Self::FieldType]) -> &mut [Self] {
fn pack_slice_mut(buf: &mut [Self::Scalar]) -> &mut [Self] {
assert!(
buf.len() % Self::WIDTH == 0,
"Slice length (got {}) must be a multiple of packed field width ({}).",
@ -95,143 +99,41 @@ pub trait PackedField:
}
}
#[derive(Copy, Clone)]
#[repr(transparent)]
pub struct Singleton<F: Field>(pub F);
unsafe impl<F: Field> PackedField for F {
type Scalar = Self;
impl<F: Field> Add<Self> for Singleton<F> {
type Output = Self;
fn add(self, rhs: Self) -> Self {
Self(self.0 + rhs.0)
}
}
impl<F: Field> Add<F> for Singleton<F> {
type Output = Self;
fn add(self, rhs: F) -> Self {
self + Self::broadcast(rhs)
}
}
impl<F: Field> AddAssign<Self> for Singleton<F> {
fn add_assign(&mut self, rhs: Self) {
*self = *self + rhs;
}
}
impl<F: Field> AddAssign<F> for Singleton<F> {
fn add_assign(&mut self, rhs: F) {
*self = *self + rhs;
}
}
impl<F: Field> Debug for Singleton<F> {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
write!(f, "({:?})", self.0)
}
}
impl<F: Field> Default for Singleton<F> {
fn default() -> Self {
Self::zero()
}
}
impl<F: Field> Mul<Self> for Singleton<F> {
type Output = Self;
fn mul(self, rhs: Self) -> Self {
Self(self.0 * rhs.0)
}
}
impl<F: Field> Mul<F> for Singleton<F> {
type Output = Self;
fn mul(self, rhs: F) -> Self {
self * Self::broadcast(rhs)
}
}
impl<F: Field> MulAssign<Self> for Singleton<F> {
fn mul_assign(&mut self, rhs: Self) {
*self = *self * rhs;
}
}
impl<F: Field> MulAssign<F> for Singleton<F> {
fn mul_assign(&mut self, rhs: F) {
*self = *self * rhs;
}
}
impl<F: Field> Neg for Singleton<F> {
type Output = Self;
fn neg(self) -> Self {
Self(-self.0)
}
}
impl<F: Field> Product for Singleton<F> {
fn product<I: Iterator<Item = Self>>(iter: I) -> Self {
Self(iter.map(|x| x.0).product())
}
}
impl<F: Field> PackedField for Singleton<F> {
const LOG2_WIDTH: usize = 0;
type FieldType = F;
fn broadcast(x: F) -> Self {
Self(x)
}
fn from_arr(arr: [Self::FieldType; Self::WIDTH]) -> Self {
Self(arr[0])
}
fn to_arr(&self) -> [Self::FieldType; Self::WIDTH] {
[self.0]
}
fn from_slice(slice: &[Self::FieldType]) -> Self {
assert!(slice.len() == 1);
Self(slice[0])
}
fn to_vec(&self) -> Vec<Self::FieldType> {
vec![self.0]
}
fn interleave(&self, other: Self, r: usize) -> (Self, Self) {
match r {
0 => (*self, other), // This is a no-op whenever r == LOG2_WIDTH.
_ => panic!("r cannot be more than LOG2_WIDTH"),
}
}
const WIDTH: usize = 1;
const ZERO: Self = <F as Field>::ZERO;
const ONE: Self = <F as Field>::ONE;
fn square(&self) -> Self {
Self(self.0.square())
<Self as Field>::square(self)
}
}
impl<F: Field> Sub<Self> for Singleton<F> {
type Output = Self;
fn sub(self, rhs: Self) -> Self {
Self(self.0 - rhs.0)
fn from_arr(arr: [Self::Scalar; Self::WIDTH]) -> Self {
arr[0]
}
}
impl<F: Field> Sub<F> for Singleton<F> {
type Output = Self;
fn sub(self, rhs: F) -> Self {
self - Self::broadcast(rhs)
fn as_arr(&self) -> [Self::Scalar; Self::WIDTH] {
[*self]
}
}
impl<F: Field> SubAssign<Self> for Singleton<F> {
fn sub_assign(&mut self, rhs: Self) {
*self = *self - rhs;
}
}
impl<F: Field> SubAssign<F> for Singleton<F> {
fn sub_assign(&mut self, rhs: F) {
*self = *self - rhs;
}
}
impl<F: Field> Sum for Singleton<F> {
fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
Self(iter.map(|x| x.0).sum())
fn from_slice(slice: &[Self::Scalar]) -> &Self {
&slice[0]
}
fn from_slice_mut(slice: &mut [Self::Scalar]) -> &mut Self {
&mut slice[0]
}
fn as_slice(&self) -> &[Self::Scalar] {
slice::from_ref(self)
}
fn as_slice_mut(&mut self) -> &mut [Self::Scalar] {
slice::from_mut(self)
}
fn interleave(&self, other: Self, block_len: usize) -> (Self, Self) {
match block_len {
1 => (*self, other),
_ => panic!("unsupported block length"),
}
}
}

View File

@ -24,7 +24,7 @@ where
ExpectedOp: Fn(u64) -> u64,
{
let inputs = test_inputs(F::ORDER);
let expected: Vec<_> = inputs.iter().map(|x| expected_op(x.clone())).collect();
let expected: Vec<_> = inputs.iter().map(|&x| expected_op(x)).collect();
let output: Vec<_> = inputs
.iter()
.cloned()
@ -144,7 +144,7 @@ macro_rules! test_prime_field_arithmetic {
fn inverse_2exp() {
type F = $field;
let v = <F as Field>::PrimeField::TWO_ADICITY;
let v = <F as Field>::TWO_ADICITY;
for e in [0, 1, 2, 3, 4, v - 2, v - 1, v, v + 1, v + 2, 123 * v] {
let x = F::TWO.exp_u64(e as u64);

View File

@ -1,4 +1,3 @@
use std::convert::TryInto;
use std::fmt;
use std::fmt::{Debug, Display, Formatter};
use std::hash::{Hash, Hasher};
@ -12,7 +11,6 @@ use rand::Rng;
use serde::{Deserialize, Serialize};
use crate::field::field_types::Field;
use crate::field::goldilocks_field::GoldilocksField;
/// The base field of the secp256k1 elliptic curve.
///
@ -36,8 +34,80 @@ fn biguint_from_array(arr: [u64; 4]) -> BigUint {
])
}
impl Secp256K1Base {
fn to_canonical_biguint(&self) -> BigUint {
impl Default for Secp256K1Base {
fn default() -> Self {
Self::ZERO
}
}
/// Equality compares the `to_biguint` values rather than the raw limbs; `to_biguint`
/// subtracts the order when the stored limbs are not below it, so two different limb
/// encodings of the same element compare equal.
impl PartialEq for Secp256K1Base {
    fn eq(&self, other: &Self) -> bool {
        let lhs = self.to_biguint();
        let rhs = other.to_biguint();
        lhs == rhs
    }
}
impl Eq for Secp256K1Base {}
/// Hashes the `to_biguint` value, the same value that `PartialEq` compares, so equal
/// elements hash identically.
impl Hash for Secp256K1Base {
    fn hash<H: Hasher>(&self, state: &mut H) {
        let value = self.to_biguint();
        value.hash(state);
    }
}
/// Displays the element as its `to_biguint` value, forwarding the formatter unchanged.
impl Display for Secp256K1Base {
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        let value = self.to_biguint();
        Display::fmt(&value, f)
    }
}
/// Debug-prints the element as its `to_biguint` value rather than as raw limbs.
impl Debug for Secp256K1Base {
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        let value = self.to_biguint();
        Debug::fmt(&value, f)
    }
}
impl Field for Secp256K1Base {
const ZERO: Self = Self([0; 4]);
const ONE: Self = Self([1, 0, 0, 0]);
const TWO: Self = Self([2, 0, 0, 0]);
const NEG_ONE: Self = Self([
0xFFFFFFFEFFFFFC2E,
0xFFFFFFFFFFFFFFFF,
0xFFFFFFFFFFFFFFFF,
0xFFFFFFFFFFFFFFFF,
]);
const TWO_ADICITY: usize = 1;
const CHARACTERISTIC_TWO_ADICITY: usize = Self::TWO_ADICITY;
// Sage: `g = GF(p).multiplicative_generator()`
const MULTIPLICATIVE_GROUP_GENERATOR: Self = Self([5, 0, 0, 0]);
// Sage: `g_2 = g^((p - 1) / 2)`
const POWER_OF_TWO_GENERATOR: Self = Self::NEG_ONE;
const BITS: usize = 256;
fn order() -> BigUint {
BigUint::from_slice(&[
0xFFFFFC2F, 0xFFFFFFFE, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF,
0xFFFFFFFF,
])
}
fn characteristic() -> BigUint {
Self::order()
}
fn try_inverse(&self) -> Option<Self> {
if self.is_zero() {
return None;
}
// Fermat's Little Theorem
Some(self.exp_biguint(&(Self::order() - BigUint::one() - BigUint::one())))
}
fn to_biguint(&self) -> BigUint {
let mut result = biguint_from_array(self.0);
if result >= Self::order() {
result -= Self::order();
@ -55,79 +125,6 @@ impl Secp256K1Base {
.expect("error converting to u64 array"),
)
}
}
impl Default for Secp256K1Base {
fn default() -> Self {
Self::ZERO
}
}
impl PartialEq for Secp256K1Base {
fn eq(&self, other: &Self) -> bool {
self.to_canonical_biguint() == other.to_canonical_biguint()
}
}
impl Eq for Secp256K1Base {}
impl Hash for Secp256K1Base {
fn hash<H: Hasher>(&self, state: &mut H) {
self.to_canonical_biguint().hash(state)
}
}
impl Display for Secp256K1Base {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
Display::fmt(&self.to_canonical_biguint(), f)
}
}
impl Debug for Secp256K1Base {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
Debug::fmt(&self.to_canonical_biguint(), f)
}
}
impl Field for Secp256K1Base {
// TODO: fix
type PrimeField = GoldilocksField;
const ZERO: Self = Self([0; 4]);
const ONE: Self = Self([1, 0, 0, 0]);
const TWO: Self = Self([2, 0, 0, 0]);
const NEG_ONE: Self = Self([
0xFFFFFFFEFFFFFC2E,
0xFFFFFFFFFFFFFFFF,
0xFFFFFFFFFFFFFFFF,
0xFFFFFFFFFFFFFFFF,
]);
// TODO: fix
const CHARACTERISTIC: u64 = 0;
const TWO_ADICITY: usize = 1;
// Sage: `g = GF(p).multiplicative_generator()`
const MULTIPLICATIVE_GROUP_GENERATOR: Self = Self([5, 0, 0, 0]);
// Sage: `g_2 = g^((p - 1) / 2)`
const POWER_OF_TWO_GENERATOR: Self = Self::NEG_ONE;
fn order() -> BigUint {
BigUint::from_slice(&[
0xFFFFFC2F, 0xFFFFFFFE, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF,
0xFFFFFFFF,
])
}
fn try_inverse(&self) -> Option<Self> {
if self.is_zero() {
return None;
}
// Fermat's Little Theorem
Some(self.exp_biguint(&(Self::order() - BigUint::one() - BigUint::one())))
}
#[inline]
fn from_canonical_u64(n: u64) -> Self {
@ -157,7 +154,7 @@ impl Neg for Secp256K1Base {
if self.is_zero() {
Self::ZERO
} else {
Self::from_biguint(Self::order() - self.to_canonical_biguint())
Self::from_biguint(Self::order() - self.to_biguint())
}
}
}
@ -167,7 +164,7 @@ impl Add for Secp256K1Base {
#[inline]
fn add(self, rhs: Self) -> Self {
let mut result = self.to_canonical_biguint() + rhs.to_canonical_biguint();
let mut result = self.to_biguint() + rhs.to_biguint();
if result >= Self::order() {
result -= Self::order();
}
@ -210,9 +207,7 @@ impl Mul for Secp256K1Base {
#[inline]
fn mul(self, rhs: Self) -> Self {
Self::from_biguint(
(self.to_canonical_biguint() * rhs.to_canonical_biguint()).mod_floor(&Self::order()),
)
Self::from_biguint((self.to_biguint() * rhs.to_biguint()).mod_floor(&Self::order()))
}
}
@ -244,3 +239,10 @@ impl DivAssign for Secp256K1Base {
*self = *self / rhs;
}
}
#[cfg(test)]
mod tests {
use crate::test_field_arithmetic;
// Expands the crate's `test_field_arithmetic!` macro for the secp256k1 base field.
test_field_arithmetic!(crate::field::secp256k1_base::Secp256K1Base);
}

View File

@ -0,0 +1,257 @@
use std::convert::TryInto;
use std::fmt;
use std::fmt::{Debug, Display, Formatter};
use std::hash::{Hash, Hasher};
use std::iter::{Product, Sum};
use std::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssign};
use itertools::Itertools;
use num::bigint::{BigUint, RandBigInt};
use num::{Integer, One};
use rand::Rng;
use serde::{Deserialize, Serialize};
use crate::field::field_types::Field;
/// The scalar field of the secp256k1 elliptic curve, i.e. the integers modulo the order of the
/// curve's group (this is not the curve's base field).
///
/// Its order is
/// ```ignore
/// P = 0xFFFFFFFF FFFFFFFF FFFFFFFF FFFFFFFE BAAEDCE6 AF48A03B BFD25E8C D0364141
/// = 115792089237316195423570985008687907852837564279074904382605163141518161494337
/// = 2**256 - 432420386565659656852420866394968145599
/// ```
///
/// The value is stored as four 64-bit limbs, least-significant limb first. The limbs are not
/// required to encode a value below `P`; `to_biguint` reduces the stored value before use.
#[derive(Copy, Clone, Serialize, Deserialize)]
pub struct Secp256K1Scalar(pub [u64; 4]);
/// Converts four little-endian 64-bit limbs into a `BigUint`.
///
/// Each limb is split into its low and high 32-bit halves (low half first), giving the eight
/// 32-bit digits that `BigUint::from_slice` expects.
fn biguint_from_array(arr: [u64; 4]) -> BigUint {
    let mut digits = [0u32; 8];
    for (halves, &limb) in digits.chunks_exact_mut(2).zip(arr.iter()) {
        // Truncating cast keeps the low 32 bits; the shift exposes the high 32 bits.
        halves[0] = limb as u32;
        halves[1] = (limb >> 32) as u32;
    }
    BigUint::from_slice(&digits)
}
impl Default for Secp256K1Scalar {
/// The default scalar is the additive identity, `Self::ZERO`.
fn default() -> Self {
Self::ZERO
}
}
impl PartialEq for Secp256K1Scalar {
/// Two scalars are equal when their reduced values agree. Comparing via `to_biguint`
/// (rather than comparing the raw limbs) makes different limb encodings of the same
/// residue compare equal.
fn eq(&self, other: &Self) -> bool {
self.to_biguint() == other.to_biguint()
}
}
// Marker impl: the equality above is declared to be a full equivalence relation.
impl Eq for Secp256K1Scalar {}
impl Hash for Secp256K1Scalar {
/// Hashes the reduced value returned by `to_biguint`, matching the `PartialEq` impl, which
/// also compares reduced values.
fn hash<H: Hasher>(&self, state: &mut H) {
self.to_biguint().hash(state)
}
}
impl Display for Secp256K1Scalar {
/// Formats the reduced value by delegating to `BigUint`'s `Display` impl, passing the
/// caller's formatter straight through.
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
Display::fmt(&self.to_biguint(), f)
}
}
impl Debug for Secp256K1Scalar {
/// Formats the reduced value by delegating to `BigUint`'s `Debug` impl, so the raw limbs
/// are not shown.
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
Debug::fmt(&self.to_biguint(), f)
}
}
impl Field for Secp256K1Scalar {
    const ZERO: Self = Self([0; 4]);
    const ONE: Self = Self([1, 0, 0, 0]);
    const TWO: Self = Self([2, 0, 0, 0]);
    /// `P - 1`, as little-endian 64-bit limbs.
    const NEG_ONE: Self = Self([
        0xBFD25E8CD0364140,
        0xBAAEDCE6AF48A03B,
        0xFFFFFFFFFFFFFFFE,
        0xFFFFFFFFFFFFFFFF,
    ]);

    const TWO_ADICITY: usize = 6;
    const CHARACTERISTIC_TWO_ADICITY: usize = Self::TWO_ADICITY;

    // Sage: `g = GF(p).multiplicative_generator()`
    const MULTIPLICATIVE_GROUP_GENERATOR: Self = Self([7, 0, 0, 0]);

    // Sage: `g_2 = power_mod(g, (p - 1) // 2^6, p)`
    // 5480320495727936603795231718619559942670027629901634955707709633242980176626
    const POWER_OF_TWO_GENERATOR: Self = Self([
        0x992F4B5402B052F2,
        0x98BDEAB680756045,
        0xDF9879A3FBC483A8,
        0x0C1DC060E7A91986,
    ]);

    const BITS: usize = 256;

    /// The modulus `P`, assembled from little-endian 32-bit digits.
    fn order() -> BigUint {
        BigUint::from_slice(&[
            0xD0364141, 0xBFD25E8C, 0xAF48A03B, 0xBAAEDCE6, 0xFFFFFFFE, 0xFFFFFFFF, 0xFFFFFFFF,
            0xFFFFFFFF,
        ])
    }

    /// The characteristic is reported as the same value as the order.
    fn characteristic() -> BigUint {
        Self::order()
    }

    /// Returns the multiplicative inverse, or `None` for zero.
    fn try_inverse(&self) -> Option<Self> {
        if self.is_zero() {
            return None;
        }

        // Fermat's Little Theorem: the inverse is `self^(P - 2)`.
        let exponent = Self::order() - BigUint::one() - BigUint::one();
        Some(self.exp_biguint(&exponent))
    }

    /// Returns the stored value reduced below `P`.
    ///
    /// A single conditional subtraction is enough: the limbs hold less than `2^256`, and `P`
    /// exceeds `2^255`.
    fn to_biguint(&self) -> BigUint {
        let raw = biguint_from_array(self.0);
        let modulus = Self::order();
        if raw >= modulus {
            raw - modulus
        } else {
            raw
        }
    }

    /// Builds a scalar from `val` without reducing it.
    ///
    /// Panics if `val` needs more than four 64-bit digits.
    fn from_biguint(val: BigUint) -> Self {
        // Pad with zero digits so that small values still produce exactly four limbs.
        let limbs: Vec<u64> = val
            .to_u64_digits()
            .into_iter()
            .pad_using(4, |_| 0)
            .collect();
        Self(
            limbs[..]
                .try_into()
                .expect("error converting to u64 array"),
        )
    }

    #[inline]
    fn from_canonical_u64(n: u64) -> Self {
        Self([n, 0, 0, 0])
    }

    #[inline]
    fn from_noncanonical_u128(n: u128) -> Self {
        let low = n as u64;
        let high = (n >> 64) as u64;
        Self([low, high, 0, 0])
    }

    #[inline]
    fn from_noncanonical_u96(n: (u64, u32)) -> Self {
        let (low, high) = n;
        Self([low, high as u64, 0, 0])
    }

    /// Samples a value strictly below `P` using `rng`.
    fn rand_from_rng<R: Rng>(rng: &mut R) -> Self {
        let sample = rng.gen_biguint_below(&Self::order());
        Self::from_biguint(sample)
    }
}
impl Neg for Secp256K1Scalar {
    type Output = Self;

    /// Returns the additive inverse: `P - self` for non-zero values, and zero for zero.
    #[inline]
    fn neg(self) -> Self {
        if !self.is_zero() {
            return Self::from_biguint(Self::order() - self.to_biguint());
        }
        // `P - 0` would not be a reduced value, so zero maps to itself.
        Self::ZERO
    }
}
impl Add for Secp256K1Scalar {
    type Output = Self;

    /// Adds two scalars modulo `P`.
    #[inline]
    fn add(self, rhs: Self) -> Self {
        let modulus = Self::order();
        let sum = self.to_biguint() + rhs.to_biguint();
        // Both operands are already reduced, so one conditional subtraction suffices.
        let reduced = if sum >= modulus { sum - modulus } else { sum };
        Self::from_biguint(reduced)
    }
}
impl AddAssign for Secp256K1Scalar {
/// In-place addition, defined through the `Add` impl.
#[inline]
fn add_assign(&mut self, rhs: Self) {
*self = *self + rhs;
}
}
impl Sum for Secp256K1Scalar {
    /// Adds up every item of `iter`, starting from `Self::ZERO`.
    fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
        let mut total = Self::ZERO;
        for term in iter {
            total = total + term;
        }
        total
    }
}
impl Sub for Secp256K1Scalar {
type Output = Self;
/// Computes `self - rhs` as `self + (-rhs)`.
#[inline]
// The lint fires because `sub` is written in terms of `+`; that is intentional here.
#[allow(clippy::suspicious_arithmetic_impl)]
fn sub(self, rhs: Self) -> Self {
self + -rhs
}
}
impl SubAssign for Secp256K1Scalar {
/// In-place subtraction, defined through the `Sub` impl.
#[inline]
fn sub_assign(&mut self, rhs: Self) {
*self = *self - rhs;
}
}
impl Mul for Secp256K1Scalar {
    type Output = Self;

    /// Multiplies two scalars and reduces the product modulo `P`.
    #[inline]
    fn mul(self, rhs: Self) -> Self {
        let product = self.to_biguint() * rhs.to_biguint();
        let reduced = product.mod_floor(&Self::order());
        Self::from_biguint(reduced)
    }
}
impl MulAssign for Secp256K1Scalar {
/// In-place multiplication, defined through the `Mul` impl.
#[inline]
fn mul_assign(&mut self, rhs: Self) {
*self = *self * rhs;
}
}
impl Product for Secp256K1Scalar {
    /// Multiplies every item of `iter` together; an empty iterator yields `Self::ONE`.
    #[inline]
    fn product<I: Iterator<Item = Self>>(mut iter: I) -> Self {
        match iter.next() {
            // Seed the accumulator with the first item, so a single item is returned untouched.
            Some(first) => iter.fold(first, |acc, x| acc * x),
            None => Self::ONE,
        }
    }
}
impl Div for Secp256K1Scalar {
type Output = Self;
/// Computes `self / rhs` as `self * rhs.inverse()`.
///
/// NOTE(review): `inverse` comes from the `Field` trait, outside this file; since
/// `try_inverse` returns `None` for zero, it presumably panics when `rhs` is zero -- confirm.
// The lint fires because `div` is written in terms of `*`; that is intentional here.
#[allow(clippy::suspicious_arithmetic_impl)]
fn div(self, rhs: Self) -> Self::Output {
self * rhs.inverse()
}
}
impl DivAssign for Secp256K1Scalar {
/// In-place division, defined through the `Div` impl.
fn div_assign(&mut self, rhs: Self) {
*self = *self / rhs;
}
}
#[cfg(test)]
mod tests {
// Invokes the crate's `test_field_arithmetic!` macro for the secp256k1 scalar field type.
// NOTE(review): the macro is defined elsewhere; presumably it expands to the shared
// field-arithmetic test suite -- confirm.
use crate::test_field_arithmetic;
test_field_arithmetic!(crate::field::secp256k1_scalar::Secp256K1Scalar);
}

View File

@ -11,14 +11,14 @@ use crate::plonk::circuit_data::CommonCircuitData;
use crate::plonk::config::GenericConfig;
use crate::plonk::plonk_common::PlonkPolynomials;
use crate::plonk::proof::OpeningSet;
use crate::polynomial::polynomial::{PolynomialCoeffs, PolynomialValues};
use crate::polynomial::{PolynomialCoeffs, PolynomialValues};
use crate::timed;
use crate::util::reducing::ReducingFactor;
use crate::util::timing::TimingTree;
use crate::util::{log2_strict, reverse_bits, reverse_index_bits_in_place, transpose};
/// Two (~64 bit) field elements gives ~128 bit security.
pub const SALT_SIZE: usize = 2;
/// Four (~64 bit) field elements gives ~128 bit security.
pub const SALT_SIZE: usize = 4;
/// Represents a batch FRI based commitment to a list of polynomials.
pub struct PolynomialBatchCommitment<F: Extendable<D>, C: GenericConfig<D, F = F>, const D: usize> {

View File

@ -36,8 +36,4 @@ impl FriParams {
pub(crate) fn max_arity_bits(&self) -> Option<usize> {
self.reduction_arity_bits.iter().copied().max()
}
pub(crate) fn max_arity(&self) -> Option<usize> {
self.max_arity_bits().map(|bits| 1 << bits)
}
}

View File

@ -16,7 +16,7 @@ use crate::plonk::circuit_data::CommonCircuitData;
use crate::plonk::config::{GenericConfig, Hasher};
use crate::plonk::plonk_common::PolynomialsIndexBlinding;
use crate::plonk::proof::{FriInferredElements, ProofChallenges};
use crate::polynomial::polynomial::PolynomialCoeffs;
use crate::polynomial::PolynomialCoeffs;
/// Evaluations and Merkle proof produced by the prover in a FRI query step.
#[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq)]

View File

@ -9,7 +9,7 @@ use crate::iop::challenger::Challenger;
use crate::plonk::circuit_data::CommonCircuitData;
use crate::plonk::config::{GenericConfig, Hasher};
use crate::plonk::plonk_common::reduce_with_powers;
use crate::polynomial::polynomial::{PolynomialCoeffs, PolynomialValues};
use crate::polynomial::{PolynomialCoeffs, PolynomialValues};
use crate::timed;
use crate::util::reverse_index_bits_in_place;
use crate::util::timing::TimingTree;

View File

@ -3,13 +3,16 @@ use crate::field::extension_field::Extendable;
use crate::field::field_types::{Field, RichField};
use crate::fri::proof::{FriInitialTreeProofTarget, FriProofTarget, FriQueryRoundTarget};
use crate::fri::FriConfig;
use crate::gadgets::interpolation::InterpolationGate;
use crate::gates::gate::Gate;
use crate::gates::interpolation::InterpolationGate;
use crate::gates::interpolation::HighDegreeInterpolationGate;
use crate::gates::low_degree_interpolation::LowDegreeInterpolationGate;
use crate::gates::random_access::RandomAccessGate;
use crate::hash::hash_types::MerkleCapTarget;
use crate::iop::challenger::RecursiveChallenger;
use crate::iop::target::{BoolTarget, Target};
use crate::plonk::circuit_builder::CircuitBuilder;
use crate::plonk::circuit_data::{CircuitConfig, CommonCircuitData};
use crate::plonk::circuit_data::CommonCircuitData;
use crate::plonk::config::{AlgebraicConfig, AlgebraicHasher, GenericConfig};
use crate::plonk::plonk_common::PlonkPolynomials;
@ -28,6 +31,7 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
arity_bits: usize,
evals: &[ExtensionTarget<D>],
beta: ExtensionTarget<D>,
common_data: &CommonCircuitData<F, D>,
) -> ExtensionTarget<D> {
let arity = 1 << arity_bits;
debug_assert_eq!(evals.len(), arity);
@ -44,37 +48,62 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
let coset_start = self.mul(start, x);
// The answer is gotten by interpolating {(x*g^i, P(x*g^i))} and evaluating at beta.
self.interpolate_coset(arity_bits, coset_start, &evals, beta)
// `HighDegreeInterpolationGate` has degree `arity`, so we use the low-degree gate if
// the arity is too large.
if arity > common_data.quotient_degree_factor {
self.interpolate_coset::<LowDegreeInterpolationGate<F, D>>(
arity_bits,
coset_start,
&evals,
beta,
)
} else {
self.interpolate_coset::<HighDegreeInterpolationGate<F, D>>(
arity_bits,
coset_start,
&evals,
beta,
)
}
}
/// Make sure we have enough wires and routed wires to do the FRI checks efficiently. This check
/// isn't required -- without it we'd get errors elsewhere in the stack -- but just gives more
/// helpful errors.
fn check_recursion_config(&self, max_fri_arity: usize) {
fn check_recursion_config(
&self,
max_fri_arity_bits: usize,
common_data: &CommonCircuitData<F, D>,
) {
let random_access = RandomAccessGate::<F, D>::new_from_config(
&self.config,
max_fri_arity.max(1 << self.config.cap_height),
max_fri_arity_bits.max(self.config.cap_height),
);
let interpolation_gate = InterpolationGate::<F, D>::new(log2_strict(max_fri_arity));
let (interpolation_wires, interpolation_routed_wires) =
if 1 << max_fri_arity_bits > common_data.quotient_degree_factor {
let gate = LowDegreeInterpolationGate::<F, D>::new(max_fri_arity_bits);
(gate.num_wires(), gate.num_routed_wires())
} else {
let gate = HighDegreeInterpolationGate::<F, D>::new(max_fri_arity_bits);
(gate.num_wires(), gate.num_routed_wires())
};
let min_wires = random_access
.num_wires()
.max(interpolation_gate.num_wires());
let min_wires = random_access.num_wires().max(interpolation_wires);
let min_routed_wires = random_access
.num_routed_wires()
.max(interpolation_gate.num_routed_wires());
.max(interpolation_routed_wires);
assert!(
self.config.num_wires >= min_wires,
"To efficiently perform FRI checks with an arity of {}, at least {} wires are needed. Consider reducing arity.",
max_fri_arity,
"To efficiently perform FRI checks with an arity of 2^{}, at least {} wires are needed. Consider reducing arity.",
max_fri_arity_bits,
min_wires
);
assert!(
self.config.num_routed_wires >= min_routed_wires,
"To efficiently perform FRI checks with an arity of {}, at least {} routed wires are needed. Consider reducing arity.",
max_fri_arity,
"To efficiently perform FRI checks with an arity of 2^{}, at least {} routed wires are needed. Consider reducing arity.",
max_fri_arity_bits,
min_routed_wires
);
}
@ -108,8 +137,8 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
) {
let config = &common_data.config;
if let Some(max_arity) = common_data.fri_params.max_arity() {
self.check_recursion_config(max_arity);
if let Some(max_arity_bits) = common_data.fri_params.max_arity_bits() {
self.check_recursion_config(max_arity_bits, common_data);
}
debug_assert_eq!(
@ -233,7 +262,7 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
common_data: &CommonCircuitData<F, C, D>,
) -> ExtensionTarget<D> {
assert!(D > 1, "Not implemented for D=1.");
let config = self.config.clone();
let config = &common_data.config;
let degree_log = common_data.degree_bits;
debug_assert_eq!(
degree_log,
@ -306,9 +335,13 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
common_data: &CommonCircuitData<F, C, D>,
) {
let n_log = log2_strict(n);
// TODO: Do we need to range check `x_index` to a target smaller than `p`?
// Note that this `low_bits` decomposition permits non-canonical binary encodings. Here we
// verify that this has a negligible impact on soundness error.
Self::assert_noncanonical_indices_ok(&common_data.config);
let x_index = challenger.get_challenge(self);
let mut x_index_bits = self.low_bits(x_index, n_log, 64);
let mut x_index_bits = self.low_bits(x_index, n_log, F::BITS);
let cap_index =
self.le_sum(x_index_bits[x_index_bits.len() - common_data.config.cap_height..].iter());
with_context!(
@ -376,6 +409,7 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
arity_bits,
evals,
betas[i],
common_data
)
);
@ -409,6 +443,26 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
);
self.connect_extension(eval, old_eval);
}
/// Sanity check that non-canonical bit decompositions of FRI query indices are harmless.
///
/// FRI query indices are decomposed into bits without verifying that the decomposition
/// supplied by the prover is the canonical one. In particular, if
/// `x_index < 2^field_bits - p`, the prover could supply the binary encoding of either
/// `x_index` or `x_index + p`, since they are congruent mod `p`. This only occurs with
/// probability
///     p_ambiguous = (2^field_bits - p) / p
/// which is small for the field that we use in practice.
///
/// The soundness error of one FRI query is roughly the codeword rate, which is much larger
/// than this ambiguous-element probability given any reasonable parameters, so ambiguous
/// elements contribute a negligible amount to soundness error. This function compares the two
/// quantities and panics if the claim does not hold. The count below is computed in `u64`, so
/// it treats the field as a 64-bit one.
fn assert_noncanonical_indices_ok(config: &CircuitConfig) {
    // `2^64 - p`, arranged so that every intermediate value fits in a `u64`.
    let ambiguous_count = u64::MAX - F::ORDER + 1;
    let per_query_error = config.rate();
    let ambiguous_probability = (ambiguous_count as f64) / (F::ORDER as f64);
    assert!(
        ambiguous_probability < per_query_error * 1e-5,
        "A non-negligible portion of field elements are in the range that permits non-canonical encodings. Need to do more analysis or enforce canonical encodings."
    );
}
}
#[derive(Copy, Clone)]

View File

@ -2,6 +2,8 @@ use std::borrow::Borrow;
use crate::field::extension_field::Extendable;
use crate::gates::arithmetic::ArithmeticExtensionGate;
use crate::field::field_types::{PrimeField, RichField};
use crate::gates::arithmetic_base::ArithmeticGate;
use crate::gates::exponentiation::ExponentiationGate;
use crate::iop::target::{BoolTarget, Target};
use crate::plonk::circuit_builder::CircuitBuilder;
@ -32,18 +34,117 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
multiplicand_1: Target,
addend: Target,
) -> Target {
let multiplicand_0_ext = self.convert_to_ext(multiplicand_0);
let multiplicand_1_ext = self.convert_to_ext(multiplicand_1);
let addend_ext = self.convert_to_ext(addend);
// If we're not configured to use the base arithmetic gate, just call arithmetic_extension.
if !self.config.use_base_arithmetic_gate {
let multiplicand_0_ext = self.convert_to_ext(multiplicand_0);
let multiplicand_1_ext = self.convert_to_ext(multiplicand_1);
let addend_ext = self.convert_to_ext(addend);
self.arithmetic_extension(
return self
.arithmetic_extension(
const_0,
const_1,
multiplicand_0_ext,
multiplicand_1_ext,
addend_ext,
)
.0[0];
}
// See if we can determine the result without adding an `ArithmeticGate`.
if let Some(result) =
self.arithmetic_special_cases(const_0, const_1, multiplicand_0, multiplicand_1, addend)
{
return result;
}
// See if we've already computed the same operation.
let operation = BaseArithmeticOperation {
const_0,
const_1,
multiplicand_0_ext,
multiplicand_1_ext,
addend_ext,
)
.0[0]
multiplicand_0,
multiplicand_1,
addend,
};
if let Some(&result) = self.base_arithmetic_results.get(&operation) {
return result;
}
// Otherwise, we must actually perform the operation using an ArithmeticGate slot.
let result = self.add_base_arithmetic_operation(operation);
self.base_arithmetic_results.insert(operation, result);
result
}
/// Places `operation` into a slot of a base `ArithmeticGate` and returns the slot's output.
///
/// The slot is obtained from `find_base_arithmetic_gate` using the operation's two constants;
/// the operation's multiplicands and addend are then connected to that slot's input wires.
fn add_base_arithmetic_operation(&mut self, operation: BaseArithmeticOperation<F>) -> Target {
    let (gate, slot) = self.find_base_arithmetic_gate(operation.const_0, operation.const_1);
    let wire = |column| Target::wire(gate, column);

    let multiplicand_0_wire = wire(ArithmeticGate::wire_ith_multiplicand_0(slot));
    let multiplicand_1_wire = wire(ArithmeticGate::wire_ith_multiplicand_1(slot));
    let addend_wire = wire(ArithmeticGate::wire_ith_addend(slot));

    self.connect(operation.multiplicand_0, multiplicand_0_wire);
    self.connect(operation.multiplicand_1, multiplicand_1_wire);
    self.connect(operation.addend, addend_wire);

    wire(ArithmeticGate::wire_ith_output(slot))
}
/// Tries to resolve `const_0 * multiplicand_0 * multiplicand_1 + const_1 * addend` without
/// adding an `ArithmeticGate`.
///
/// Returns `Some(target)` when the result is a known constant or is one of the input targets
/// unchanged, and `None` when a gate is actually required.
fn arithmetic_special_cases(
    &mut self,
    const_0: F,
    const_1: F,
    multiplicand_0: Target,
    multiplicand_1: Target,
    addend: Target,
) -> Option<Target> {
    let zero = self.zero();

    let m0_value = self.target_as_constant(multiplicand_0);
    let m1_value = self.target_as_constant(multiplicand_1);
    let addend_value = self.target_as_constant(addend);

    let product_vanishes =
        const_0 == F::ZERO || multiplicand_0 == zero || multiplicand_1 == zero;
    let addend_vanishes = const_1 == F::ZERO || addend == zero;

    // Constant value of each term, when it can be determined.
    let product_value = match (product_vanishes, m0_value, m1_value) {
        (true, _, _) => Some(F::ZERO),
        (false, Some(x), Some(y)) => Some(x * y * const_0),
        _ => None,
    };
    let addend_term_value = if addend_vanishes {
        Some(F::ZERO)
    } else {
        addend_value.map(|x| x * const_1)
    };

    // Both terms constant: the whole expression is their (constant) sum.
    if let (Some(x), Some(y)) = (product_value, addend_term_value) {
        return Some(self.constant(x + y));
    }

    // `0 + 1 * addend` is just the addend.
    if product_vanishes && const_1.is_one() {
        return Some(addend);
    }

    // With no addend term, a multiplicand whose constant value times `const_0` equals one
    // leaves the other multiplicand unchanged.
    if addend_vanishes {
        let candidates = [
            (m0_value, multiplicand_1),
            (m1_value, multiplicand_0),
        ];
        for &(value, other) in &candidates {
            if matches!(value, Some(x) if (x * const_0).is_one()) {
                return Some(other);
            }
        }
    }

    None
}
/// Computes `x * y + z`.
@ -53,20 +154,20 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
/// Computes `x + C`.
pub fn add_const(&mut self, x: Target, c: F) -> Target {
let one = self.one();
self.arithmetic(F::ONE, c, one, x, one)
let c = self.constant(c);
self.add(x, c)
}
/// Computes `C * x`.
pub fn mul_const(&mut self, c: F, x: Target) -> Target {
let zero = self.zero();
self.mul_const_add(c, x, zero)
let c = self.constant(c);
self.mul(c, x)
}
/// Computes `C * x + y`.
pub fn mul_const_add(&mut self, c: F, x: Target, y: Target) -> Target {
let one = self.one();
self.arithmetic(c, F::ONE, x, one, y)
let c = self.constant(c);
self.mul_add(c, x, y)
}
/// Computes `x * y - z`.
@ -82,13 +183,8 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
}
/// Add `n` `Target`s.
// TODO: Can be made `D` times more efficient by using all wires of an `ArithmeticExtensionGate`.
pub fn add_many(&mut self, terms: &[Target]) -> Target {
let terms_ext = terms
.iter()
.map(|&t| self.convert_to_ext(t))
.collect::<Vec<_>>();
self.add_many_extension(&terms_ext).to_target_array()[0]
terms.iter().fold(self.zero(), |acc, &t| self.add(acc, t))
}
/// Computes `x - y`.
@ -106,16 +202,16 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
/// Multiply `n` `Target`s.
pub fn mul_many(&mut self, terms: &[Target]) -> Target {
let terms_ext = terms
terms
.iter()
.map(|&t| self.convert_to_ext(t))
.collect::<Vec<_>>();
self.mul_many_extension(&terms_ext).to_target_array()[0]
.copied()
.reduce(|acc, t| self.mul(acc, t))
.unwrap_or_else(|| self.one())
}
/// Exponentiate `base` to the power of `2^power_log`.
pub fn exp_power_of_2(&mut self, base: Target, power_log: usize) -> Target {
if power_log > ArithmeticExtensionGate::<D>::new_from_config(&self.config).num_ops {
if power_log > self.num_base_arithmetic_ops_per_gate() {
// Cheaper to just use `ExponentiationGate`.
return self.exp_u64(base, 1 << power_log);
}
@ -169,8 +265,7 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
let base_t = self.constant(base);
let exponent_bits: Vec<_> = exponent_bits.into_iter().map(|b| *b.borrow()).collect();
if exponent_bits.len() > ArithmeticExtensionGate::<D>::new_from_config(&self.config).num_ops
{
if exponent_bits.len() > self.num_base_arithmetic_ops_per_gate() {
// Cheaper to just use `ExponentiationGate`.
return self.exp_from_bits(base_t, exponent_bits);
}
@ -220,3 +315,13 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
self.inverse_extension(x_ext).0[0]
}
}
/// Represents a base arithmetic operation in the circuit. Used to memoize results.
///
/// The operation stands for `const_0 * multiplicand_0 * multiplicand_1 + const_1 * addend`.
/// It derives `Eq` and `Hash` so that it can serve as the lookup key for previously computed
/// results.
#[derive(Copy, Clone, Eq, PartialEq, Hash)]
pub(crate) struct BaseArithmeticOperation<F: PrimeField> {
// Coefficient of the product term.
const_0: F,
// Coefficient of the addend term.
const_1: F,
// First factor of the product term.
multiplicand_0: Target,
// Second factor of the product term.
multiplicand_1: Target,
// Target that is scaled by `const_1` and added to the product term.
addend: Target,
}

View File

@ -1,8 +1,9 @@
use std::convert::TryInto;
use crate::field::extension_field::target::{ExtensionAlgebraTarget, ExtensionTarget};
use crate::field::extension_field::FieldExtension;
use crate::field::extension_field::{Extendable, OEF};
use crate::field::field_types::{Field, PrimeField, RichField};
use crate::gates::arithmetic_extension::ArithmeticExtensionGate;
use crate::gates::multiplication_extension::MulExtensionGate;
use crate::field::field_types::{Field, PrimeField};
use crate::gates::arithmetic::ArithmeticExtensionGate;
use crate::iop::generator::{GeneratedValues, SimpleGenerator};
@ -12,33 +13,6 @@ use crate::plonk::circuit_builder::CircuitBuilder;
use crate::util::bits_u64;
impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
/// Finds the last available arithmetic gate with the given constants or add one if there aren't any.
/// Returns `(g,i)` such that there is an arithmetic gate with the given constants at index
/// `g` and the gate's `i`-th operation is available.
fn find_arithmetic_gate(&mut self, const_0: F, const_1: F) -> (usize, usize) {
let (gate, i) = self
.free_arithmetic
.get(&(const_0, const_1))
.copied()
.unwrap_or_else(|| {
let gate = self.add_gate(
ArithmeticExtensionGate::new_from_config(&self.config),
vec![const_0, const_1],
);
(gate, 0)
});
// Update `free_arithmetic` with new values.
if i < ArithmeticExtensionGate::<D>::num_ops(&self.config) - 1 {
self.free_arithmetic
.insert((const_0, const_1), (gate, i + 1));
} else {
self.free_arithmetic.remove(&(const_0, const_1));
}
(gate, i)
}
pub fn arithmetic_extension(
&mut self,
const_0: F,
@ -59,7 +33,7 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
}
// See if we've already computed the same operation.
let operation = ArithmeticOperation {
let operation = ExtensionArithmeticOperation {
const_0,
const_1,
multiplicand_0,
@ -70,15 +44,21 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
return result;
}
let result = if self.target_as_constant_ext(addend) == Some(F::Extension::ZERO) {
// If the addend is zero, we use a multiplication gate.
self.compute_mul_extension_operation(operation)
} else {
// Otherwise, we use an arithmetic gate.
self.compute_arithmetic_extension_operation(operation)
};
// Otherwise, we must actually perform the operation using an ArithmeticExtensionGate slot.
let result = self.add_arithmetic_extension_operation(operation);
self.arithmetic_results.insert(operation, result);
result
}
fn add_arithmetic_extension_operation(
fn compute_arithmetic_extension_operation(
&mut self,
operation: ArithmeticOperation<F, D>,
operation: ExtensionArithmeticOperation<F, D>,
) -> ExtensionTarget<D> {
let (gate, i) = self.find_arithmetic_gate(operation.const_0, operation.const_1);
let wires_multiplicand_0 = ExtensionTarget::from_range(
@ -99,6 +79,22 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
ExtensionTarget::from_range(gate, ArithmeticExtensionGate::<D>::wires_ith_output(i))
}
/// Places the product part of `operation` into a slot of a `MulExtensionGate` and returns
/// the slot's output.
///
/// Only `const_0` and the two multiplicands are used; the operation's `const_1` and `addend`
/// are not wired anywhere by this function.
fn compute_mul_extension_operation(
    &mut self,
    operation: ExtensionArithmeticOperation<F, D>,
) -> ExtensionTarget<D> {
    let (gate, slot) = self.find_mul_gate(operation.const_0);
    let wires = |range| ExtensionTarget::from_range(gate, range);

    let multiplicand_0_wires = wires(MulExtensionGate::<D>::wires_ith_multiplicand_0(slot));
    let multiplicand_1_wires = wires(MulExtensionGate::<D>::wires_ith_multiplicand_1(slot));

    self.connect_extension(operation.multiplicand_0, multiplicand_0_wires);
    self.connect_extension(operation.multiplicand_1, multiplicand_1_wires);

    wires(MulExtensionGate::<D>::wires_ith_output(slot))
}
/// Checks for special cases where the value of
/// `const_0 * multiplicand_0 * multiplicand_1 + const_1 * addend`
/// can be determined without adding an `ArithmeticGate`.
@ -302,11 +298,11 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
/// Multiply `n` `ExtensionTarget`s.
pub fn mul_many_extension(&mut self, terms: &[ExtensionTarget<D>]) -> ExtensionTarget<D> {
let mut product = self.one_extension();
for &term in terms {
product = self.mul_extension(product, term);
}
product
terms
.iter()
.copied()
.reduce(|acc, t| self.mul_extension(acc, t))
.unwrap_or_else(|| self.one_extension())
}
/// Like `mul_add`, but for `ExtensionTarget`s.
@ -321,14 +317,14 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
/// Like `add_const`, but for `ExtensionTarget`s.
pub fn add_const_extension(&mut self, x: ExtensionTarget<D>, c: F) -> ExtensionTarget<D> {
let one = self.one_extension();
self.arithmetic_extension(F::ONE, c, one, x, one)
let c = self.constant_extension(c.into());
self.add_extension(x, c)
}
/// Like `mul_const`, but for `ExtensionTarget`s.
pub fn mul_const_extension(&mut self, c: F, x: ExtensionTarget<D>) -> ExtensionTarget<D> {
let zero = self.zero_extension();
self.mul_const_add_extension(c, x, zero)
let c = self.constant_extension(c.into());
self.mul_extension(c, x)
}
/// Like `mul_const_add`, but for `ExtensionTarget`s.
@ -338,8 +334,8 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
x: ExtensionTarget<D>,
y: ExtensionTarget<D>,
) -> ExtensionTarget<D> {
let one = self.one_extension();
self.arithmetic_extension(c, F::ONE, x, one, y)
let c = self.constant_extension(c.into());
self.mul_add_extension(c, x, y)
}
/// Like `mul_add`, but for `ExtensionTarget`s.
@ -544,9 +540,9 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
}
}
/// Represents an arithmetic operation in the circuit. Used to memoize results.
/// Represents an extension arithmetic operation in the circuit. Used to memoize results.
#[derive(Copy, Clone, Eq, PartialEq, Hash)]
pub(crate) struct ArithmeticOperation<F: PrimeField + Extendable<D>, const D: usize> {
pub(crate) struct ExtensionArithmeticOperation<F: PrimeField + Extendable<D>, const D: usize> {
const_0: F,
const_1: F,
multiplicand_0: ExtensionTarget<D>,
@ -556,11 +552,11 @@ pub(crate) struct ArithmeticOperation<F: PrimeField + Extendable<D>, const D: us
#[cfg(test)]
mod tests {
use std::convert::TryInto;
use anyhow::Result;
use crate::field::extension_field::algebra::ExtensionAlgebra;
use crate::field::extension_field::quartic::QuarticExtension;
use crate::field::extension_field::target::ExtensionAlgebraTarget;
use crate::field::field_types::Field;
use crate::iop::witness::{PartialWitness, Witness};
use crate::plonk::circuit_builder::CircuitBuilder;
@ -623,9 +619,7 @@ mod tests {
let yt = builder.constant_extension(y);
let zt = builder.constant_extension(z);
let comp_zt = builder.div_extension(xt, yt);
let comp_zt_unsafe = builder.div_extension(xt, yt);
builder.connect_extension(zt, comp_zt);
builder.connect_extension(zt, comp_zt_unsafe);
let data = builder.build::<C>();
let proof = data.prove(pw)?;
@ -642,23 +636,29 @@ mod tests {
let config = CircuitConfig::standard_recursion_config();
let pw = PartialWitness::new();
let mut pw = PartialWitness::new();
let mut builder = CircuitBuilder::<F, D>::new(config);
let x = FF::rand_vec(D);
let y = FF::rand_vec(D);
let xa = ExtensionAlgebra(x.try_into().unwrap());
let ya = ExtensionAlgebra(y.try_into().unwrap());
let za = xa * ya;
let xt = builder.constant_ext_algebra(xa);
let yt = builder.constant_ext_algebra(ya);
let zt = builder.constant_ext_algebra(za);
let xt =
ExtensionAlgebraTarget(builder.add_virtual_extension_targets(D).try_into().unwrap());
let yt =
ExtensionAlgebraTarget(builder.add_virtual_extension_targets(D).try_into().unwrap());
let zt =
ExtensionAlgebraTarget(builder.add_virtual_extension_targets(D).try_into().unwrap());
let comp_zt = builder.mul_ext_algebra(xt, yt);
for i in 0..D {
builder.connect_extension(zt.0[i], comp_zt.0[i]);
}
let x = ExtensionAlgebra::<FF, D>(FF::rand_arr());
let y = ExtensionAlgebra::<FF, D>(FF::rand_arr());
let z = x * y;
for i in 0..D {
pw.set_extension_target(xt.0[i], x.0[i]);
pw.set_extension_target(yt.0[i], y.0[i]);
pw.set_extension_target(zt.0[i], z.0[i]);
}
let data = builder.build::<C>();
let proof = data.prove(pw)?;

View File

@ -0,0 +1,154 @@
use crate::field::extension_field::Extendable;
use crate::field::field_types::RichField;
use crate::gates::arithmetic_u32::U32ArithmeticGate;
use crate::gates::subtraction_u32::U32SubtractionGate;
use crate::iop::target::Target;
use crate::plonk::circuit_builder::CircuitBuilder;
/// A `Target` used by the u32 arithmetic gadgets.
///
/// NOTE(review): nothing in this type range-checks the wrapped target; presumably the gates
/// that produce or consume it constrain the value to 32 bits -- confirm.
#[derive(Clone, Copy, Debug)]
pub struct U32Target(pub Target);
impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
pub fn add_virtual_u32_target(&mut self) -> U32Target {
U32Target(self.add_virtual_target())
}
pub fn add_virtual_u32_targets(&mut self, n: usize) -> Vec<U32Target> {
self.add_virtual_targets(n)
.into_iter()
.map(U32Target)
.collect()
}
pub fn zero_u32(&mut self) -> U32Target {
U32Target(self.zero())
}
pub fn one_u32(&mut self) -> U32Target {
U32Target(self.one())
}
pub fn connect_u32(&mut self, x: U32Target, y: U32Target) {
self.connect(x.0, y.0)
}
pub fn assert_zero_u32(&mut self, x: U32Target) {
self.assert_zero(x.0)
}
/// Checks for special cases where the value of
/// `x * y + z`
/// can be determined without adding a `U32ArithmeticGate`.
pub fn arithmetic_u32_special_cases(
&mut self,
x: U32Target,
y: U32Target,
z: U32Target,
) -> Option<(U32Target, U32Target)> {
let x_const = self.target_as_constant(x.0);
let y_const = self.target_as_constant(y.0);
let z_const = self.target_as_constant(z.0);
// If both terms are constant, return their (constant) sum.
let first_term_const = if let (Some(xx), Some(yy)) = (x_const, y_const) {
Some(xx * yy)
} else {
None
};
if let (Some(a), Some(b)) = (first_term_const, z_const) {
let sum = (a + b).to_canonical_u64();
let (low, high) = (sum as u32, (sum >> 32) as u32);
return Some((self.constant_u32(low), self.constant_u32(high)));
}
None
}
// Returns x * y + z.
pub fn mul_add_u32(
&mut self,
x: U32Target,
y: U32Target,
z: U32Target,
) -> (U32Target, U32Target) {
if let Some(result) = self.arithmetic_u32_special_cases(x, y, z) {
return result;
}
let gate = U32ArithmeticGate::<F, D>::new_from_config(&self.config);
let (gate_index, copy) = self.find_u32_arithmetic_gate();
self.connect(
Target::wire(gate_index, gate.wire_ith_multiplicand_0(copy)),
x.0,
);
self.connect(
Target::wire(gate_index, gate.wire_ith_multiplicand_1(copy)),
y.0,
);
self.connect(Target::wire(gate_index, gate.wire_ith_addend(copy)), z.0);
let output_low = U32Target(Target::wire(
gate_index,
gate.wire_ith_output_low_half(copy),
));
let output_high = U32Target(Target::wire(
gate_index,
gate.wire_ith_output_high_half(copy),
));
(output_low, output_high)
}
pub fn add_u32(&mut self, a: U32Target, b: U32Target) -> (U32Target, U32Target) {
let one = self.one_u32();
self.mul_add_u32(a, one, b)
}
pub fn add_many_u32(&mut self, to_add: &[U32Target]) -> (U32Target, U32Target) {
match to_add.len() {
0 => (self.zero_u32(), self.zero_u32()),
1 => (to_add[0], self.zero_u32()),
2 => self.add_u32(to_add[0], to_add[1]),
_ => {
let (mut low, mut carry) = self.add_u32(to_add[0], to_add[1]);
for i in 2..to_add.len() {
let (new_low, new_carry) = self.add_u32(to_add[i], low);
let (combined_carry, _zero) = self.add_u32(carry, new_carry);
low = new_low;
carry = combined_carry;
}
(low, carry)
}
}
}
pub fn mul_u32(&mut self, a: U32Target, b: U32Target) -> (U32Target, U32Target) {
let zero = self.zero_u32();
self.mul_add_u32(a, b, zero)
}
/// Computes `x - y - borrow` using a `U32SubtractionGate`.
///
/// Returns `(result, borrow_out)`, where `borrow_out` is 0 or 1 and signals whether a borrow
/// from the next digit was needed (i.e. iff `y + borrow > x`).
pub fn sub_u32(
    &mut self,
    x: U32Target,
    y: U32Target,
    borrow: U32Target,
) -> (U32Target, U32Target) {
    let gate = U32SubtractionGate::<F, D>::new_from_config(&self.config);
    let (row, slot) = self.find_u32_subtraction_gate();
    // Every wire touched below lives on the same gate row.
    let wire = |index| Target::wire(row, index);

    self.connect(wire(gate.wire_ith_input_x(slot)), x.0);
    self.connect(wire(gate.wire_ith_input_y(slot)), y.0);
    self.connect(wire(gate.wire_ith_input_borrow(slot)), borrow.0);

    let difference = U32Target(wire(gate.wire_ith_output_result(slot)));
    let borrow_out = U32Target(wire(gate.wire_ith_output_borrow(slot)));
    (difference, borrow_out)
}
}

395
src/gadgets/biguint.rs Normal file
View File

@ -0,0 +1,395 @@
use std::marker::PhantomData;
use num::{BigUint, Integer};
use crate::field::extension_field::Extendable;
use crate::field::field_types::RichField;
use crate::gadgets::arithmetic_u32::U32Target;
use crate::iop::generator::{GeneratedValues, SimpleGenerator};
use crate::iop::target::{BoolTarget, Target};
use crate::iop::witness::{PartitionWitness, Witness};
use crate::plonk::circuit_builder::CircuitBuilder;
/// An arbitrary-precision unsigned integer inside a circuit, represented by 32-bit limbs.
#[derive(Clone, Debug)]
pub struct BigUintTarget {
/// The limbs of the integer, least-significant limb first (the arithmetic gadgets propagate
/// carries and borrows from index 0 upwards).
pub limbs: Vec<U32Target>,
}
impl BigUintTarget {
    /// Number of 32-bit limbs in this target.
    pub fn num_limbs(&self) -> usize {
        self.limbs.len()
    }

    /// Returns the `i`th limb.
    ///
    /// # Panics
    /// Panics if `i` is out of bounds.
    pub fn get_limb(&self, i: usize) -> U32Target {
        self.limbs[i]
    }
}
impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
/// Creates a constant `BigUintTarget`, with one constant limb per u32 digit of `value`.
pub fn constant_biguint(&mut self, value: &BigUint) -> BigUintTarget {
    let limbs = value
        .to_u32_digits()
        .into_iter()
        .map(|digit| self.constant_u32(digit))
        .collect();
    BigUintTarget { limbs }
}
/// Constrains `lhs` and `rhs` to represent the same integer.
///
/// The limbs both operands have in common are connected pairwise; any surplus limbs of the
/// longer operand are asserted to be zero.
pub fn connect_biguint(&mut self, lhs: &BigUintTarget, rhs: &BigUintTarget) {
    let shared = lhs.num_limbs().min(rhs.num_limbs());

    // `zip` stops at the shorter operand, i.e. after `shared` pairs.
    for (&left, &right) in lhs.limbs.iter().zip(&rhs.limbs) {
        self.connect_u32(left, right);
    }
    for &surplus in lhs.limbs.iter().skip(shared) {
        self.assert_zero_u32(surplus);
    }
    for &surplus in rhs.limbs.iter().skip(shared) {
        self.assert_zero_u32(surplus);
    }
}
/// Returns copies of `a` and `b` in which the shorter operand has been extended with zero
/// limbs so that both have the same number of limbs.
pub fn pad_biguints(
    &mut self,
    a: &BigUintTarget,
    b: &BigUintTarget,
) -> (BigUintTarget, BigUintTarget) {
    let mut padded_a = a.clone();
    let mut padded_b = b.clone();

    if a.num_limbs() > b.num_limbs() {
        let missing = a.num_limbs() - b.num_limbs();
        padded_b
            .limbs
            .extend((0..missing).map(|_| self.zero_u32()));
    } else {
        // Also covers equal lengths, in which case nothing is appended.
        let missing = b.num_limbs() - a.num_limbs();
        padded_a
            .limbs
            .extend((0..missing).map(|_| self.zero_u32()));
    }

    (padded_a, padded_b)
}
/// Returns a boolean target that is true iff `a <= b`.
///
/// Both operands are padded to a common length and compared limb-wise with `list_le_u32`.
pub fn cmp_biguint(&mut self, a: &BigUintTarget, b: &BigUintTarget) -> BoolTarget {
    let (lhs, rhs) = self.pad_biguints(a, b);
    self.list_le_u32(lhs.limbs, rhs.limbs)
}
/// Allocates a `BigUintTarget` made of `num_limbs` fresh virtual u32 targets.
pub fn add_virtual_biguint_target(&mut self, num_limbs: usize) -> BigUintTarget {
    let mut limbs = Vec::with_capacity(num_limbs);
    for _ in 0..num_limbs {
        limbs.push(self.add_virtual_u32_target());
    }
    BigUintTarget { limbs }
}
/// Adds two `BigUintTarget`s.
///
/// The result has one limb more than the longer operand; the extra top limb holds the final
/// carry. Limbs missing from the shorter operand are treated as zero.
pub fn add_biguint(&mut self, a: &BigUintTarget, b: &BigUintTarget) -> BigUintTarget {
    let width = a.num_limbs().max(b.num_limbs());

    let mut limbs = Vec::with_capacity(width + 1);
    let mut carry = self.zero_u32();
    for i in 0..width {
        let a_limb = match a.limbs.get(i) {
            Some(&limb) => limb,
            None => self.zero_u32(),
        };
        let b_limb = match b.limbs.get(i) {
            Some(&limb) => limb,
            None => self.zero_u32(),
        };

        let (sum, carry_out) = self.add_many_u32(&[carry, a_limb, b_limb]);
        carry = carry_out;
        limbs.push(sum);
    }
    limbs.push(carry);

    BigUintTarget { limbs }
}
/// Computes `a - b`, assuming `a >= b`.
///
/// The operands are padded to a common length and subtracted limb by limb with a running
/// borrow. The final borrow is not constrained here, so the result is only meaningful when the
/// assumption holds.
pub fn sub_biguint(&mut self, a: &BigUintTarget, b: &BigUintTarget) -> BigUintTarget {
    let (minuend, subtrahend) = self.pad_biguints(a, b);

    let mut limbs = Vec::with_capacity(minuend.limbs.len());
    let mut borrow = self.zero_u32();
    for (&x, &y) in minuend.limbs.iter().zip(&subtrahend.limbs) {
        let (difference, borrow_out) = self.sub_u32(x, y, borrow);
        limbs.push(difference);
        borrow = borrow_out;
    }
    // If `a >= b` as assumed, the borrow is zero at this point.

    BigUintTarget { limbs }
}
pub fn mul_biguint(&mut self, a: &BigUintTarget, b: &BigUintTarget) -> BigUintTarget {
let total_limbs = a.limbs.len() + b.limbs.len();
let mut to_add = vec![vec![]; total_limbs];
for i in 0..a.limbs.len() {
for j in 0..b.limbs.len() {
let (product, carry) = self.mul_u32(a.limbs[i], b.limbs[j]);
to_add[i + j].push(product);
to_add[i + j + 1].push(carry);
}
}
let mut combined_limbs = vec![];
let mut carry = self.zero_u32();
for summands in &mut to_add {
summands.push(carry);
let (new_result, new_carry) = self.add_many_u32(summands);
combined_limbs.push(new_result);
carry = new_carry;
}
combined_limbs.push(carry);
BigUintTarget {
limbs: combined_limbs,
}
}
/// Returns `x * y + z`.
///
/// Purely a convenience wrapper: it is exactly `mul_biguint` followed by `add_biguint`, and
/// no cheaper than calling the two separately.
pub fn mul_add_biguint(
    &mut self,
    x: &BigUintTarget,
    y: &BigUintTarget,
    z: &BigUintTarget,
) -> BigUintTarget {
    let product = self.mul_biguint(x, y);
    self.add_biguint(&product, z)
}
/// Computes the quotient and remainder of `a / b`, returned as `(div, rem)`.
///
/// The quotient and remainder are supplied as witness values by a `BigUintDivRemGenerator`
/// and then constrained by `a == div * b + rem` together with a comparison between `rem`
/// and `b`.
///
/// `div` has `a_len + 1 - b_len` limbs (none if `b` has more limbs than `a`), and `rem` has
/// as many limbs as `b`.
pub fn div_rem_biguint(
    &mut self,
    a: &BigUintTarget,
    b: &BigUintTarget,
) -> (BigUintTarget, BigUintTarget) {
    let a_len = a.limbs.len();
    let b_len = b.limbs.len();
    // Saturate instead of computing `a_len - b_len + 1` directly: that expression underflows
    // (and panics when overflow checks are on) for `b_len == a_len + 1`, a case the previous
    // `b_len > a_len + 1` guard did not cover.
    let div_num_limbs = (a_len + 1).saturating_sub(b_len);
    let div = self.add_virtual_biguint_target(div_num_limbs);
    let rem = self.add_virtual_biguint_target(b_len);

    self.add_simple_generator(BigUintDivRemGenerator::<F, D> {
        a: a.clone(),
        b: b.clone(),
        div: div.clone(),
        rem: rem.clone(),
        _phantom: PhantomData,
    });

    // Constrain a == div * b + rem.
    let div_b = self.mul_biguint(&div, b);
    let div_b_plus_rem = self.add_biguint(&div_b, &rem);
    self.connect_biguint(a, &div_b_plus_rem);

    // NOTE(review): `cmp_biguint` is a less-than-or-equal comparison, so this asserts
    // `rem <= b` rather than the strict `rem < b` of Euclidean division. Confirm whether
    // allowing `rem == b` is intended.
    let cmp_rem_b = self.cmp_biguint(&rem, b);
    self.assert_one(cmp_rem_b.target);

    (div, rem)
}
/// Returns the quotient of `a / b`; see `div_rem_biguint`.
pub fn div_biguint(&mut self, a: &BigUintTarget, b: &BigUintTarget) -> BigUintTarget {
    self.div_rem_biguint(a, b).0
}
/// Returns the remainder of `a / b`; see `div_rem_biguint`.
pub fn rem_biguint(&mut self, a: &BigUintTarget, b: &BigUintTarget) -> BigUintTarget {
    self.div_rem_biguint(a, b).1
}
}
/// Witness generator that fills in the quotient and remainder targets of a big-integer
/// division once the dividend and divisor are known.
#[derive(Debug)]
struct BigUintDivRemGenerator<F: RichField + Extendable<D>, const D: usize> {
// Dividend.
a: BigUintTarget,
// Divisor.
b: BigUintTarget,
// Output target receiving the quotient of `a / b`.
div: BigUintTarget,
// Output target receiving the remainder of `a / b`.
rem: BigUintTarget,
// Ties the otherwise unused field type `F` to the struct.
_phantom: PhantomData<F>,
}
impl<F: RichField + Extendable<D>, const D: usize> SimpleGenerator<F>
    for BigUintDivRemGenerator<F, D>
{
    /// The generator can run once every limb of the dividend and the divisor is known.
    fn dependencies(&self) -> Vec<Target> {
        let input_limbs = self.a.limbs.iter().chain(self.b.limbs.iter());
        input_limbs.map(|limb| limb.0).collect()
    }

    /// Reads `a` and `b` from the witness, divides them, and writes the quotient and the
    /// remainder to the `div` and `rem` targets.
    fn run_once(&self, witness: &PartitionWitness<F>, out_buffer: &mut GeneratedValues<F>) {
        let dividend = witness.get_biguint_target(self.a.clone());
        let divisor = witness.get_biguint_target(self.b.clone());

        let (quotient, remainder) = dividend.div_rem(&divisor);

        out_buffer.set_biguint_target(self.div.clone(), quotient);
        out_buffer.set_biguint_target(self.rem.clone(), remainder);
    }
}
// Tests for the big-integer gadgets. Each test builds a small circuit around one gadget,
// constrains its output to a value computed natively with `num::BigUint`, then proves and
// verifies the circuit.
#[cfg(test)]
mod tests {
use anyhow::Result;
use num::{BigUint, FromPrimitive, Integer};
use rand::Rng;
use crate::iop::witness::Witness;
use crate::{
field::goldilocks_field::GoldilocksField,
iop::witness::PartialWitness,
plonk::{circuit_builder::CircuitBuilder, circuit_data::CircuitConfig, verifier::verify},
};
// Adds two random 128-bit values supplied through the partial witness.
#[test]
fn test_biguint_add() -> Result<()> {
let mut rng = rand::thread_rng();
let x_value = BigUint::from_u128(rng.gen()).unwrap();
let y_value = BigUint::from_u128(rng.gen()).unwrap();
let expected_z_value = &x_value + &y_value;
type F = GoldilocksField;
let config = CircuitConfig::standard_recursion_config();
let mut pw = PartialWitness::new();
let mut builder = CircuitBuilder::<F, 4>::new(config);
// Virtual targets sized to the number of u32 digits of each concrete value.
let x = builder.add_virtual_biguint_target(x_value.to_u32_digits().len());
let y = builder.add_virtual_biguint_target(y_value.to_u32_digits().len());
let z = builder.add_biguint(&x, &y);
let expected_z = builder.add_virtual_biguint_target(expected_z_value.to_u32_digits().len());
builder.connect_biguint(&z, &expected_z);
pw.set_biguint_target(&x, &x_value);
pw.set_biguint_target(&y, &y_value);
pw.set_biguint_target(&expected_z, &expected_z_value);
let data = builder.build();
let proof = data.prove(pw).unwrap();
verify(proof, &data.verifier_only, &data.common)
}
// Subtracts two random 128-bit constants.
#[test]
fn test_biguint_sub() -> Result<()> {
let mut rng = rand::thread_rng();
let mut x_value = BigUint::from_u128(rng.gen()).unwrap();
let mut y_value = BigUint::from_u128(rng.gen()).unwrap();
// `sub_biguint` assumes the first operand is the larger one, so order the values.
if y_value > x_value {
(x_value, y_value) = (y_value, x_value);
}
let expected_z_value = &x_value - &y_value;
type F = GoldilocksField;
let config = CircuitConfig::standard_recursion_config();
let pw = PartialWitness::new();
let mut builder = CircuitBuilder::<F, 4>::new(config);
let x = builder.constant_biguint(&x_value);
let y = builder.constant_biguint(&y_value);
let z = builder.sub_biguint(&x, &y);
let expected_z = builder.constant_biguint(&expected_z_value);
builder.connect_biguint(&z, &expected_z);
let data = builder.build();
let proof = data.prove(pw).unwrap();
verify(proof, &data.verifier_only, &data.common)
}
// Multiplies two random 128-bit values supplied through the partial witness.
#[test]
fn test_biguint_mul() -> Result<()> {
let mut rng = rand::thread_rng();
let x_value = BigUint::from_u128(rng.gen()).unwrap();
let y_value = BigUint::from_u128(rng.gen()).unwrap();
let expected_z_value = &x_value * &y_value;
type F = GoldilocksField;
let config = CircuitConfig::standard_recursion_config();
let mut pw = PartialWitness::new();
let mut builder = CircuitBuilder::<F, 4>::new(config);
let x = builder.add_virtual_biguint_target(x_value.to_u32_digits().len());
let y = builder.add_virtual_biguint_target(y_value.to_u32_digits().len());
let z = builder.mul_biguint(&x, &y);
let expected_z = builder.add_virtual_biguint_target(expected_z_value.to_u32_digits().len());
builder.connect_biguint(&z, &expected_z);
pw.set_biguint_target(&x, &x_value);
pw.set_biguint_target(&y, &y_value);
pw.set_biguint_target(&expected_z, &expected_z_value);
let data = builder.build();
let proof = data.prove(pw).unwrap();
verify(proof, &data.verifier_only, &data.common)
}
// Checks that `cmp_biguint` agrees with the native `<=` on two random constants.
#[test]
fn test_biguint_cmp() -> Result<()> {
let mut rng = rand::thread_rng();
let x_value = BigUint::from_u128(rng.gen()).unwrap();
let y_value = BigUint::from_u128(rng.gen()).unwrap();
type F = GoldilocksField;
let config = CircuitConfig::standard_recursion_config();
let pw = PartialWitness::new();
let mut builder = CircuitBuilder::<F, 4>::new(config);
let x = builder.constant_biguint(&x_value);
let y = builder.constant_biguint(&y_value);
let cmp = builder.cmp_biguint(&x, &y);
let expected_cmp = builder.constant_bool(x_value <= y_value);
builder.connect(cmp.target, expected_cmp.target);
let data = builder.build();
let proof = data.prove(pw).unwrap();
verify(proof, &data.verifier_only, &data.common)
}
// Divides the larger of two random 128-bit constants by the smaller one and checks both
// the quotient and the remainder.
#[test]
fn test_biguint_div_rem() -> Result<()> {
let mut rng = rand::thread_rng();
let mut x_value = BigUint::from_u128(rng.gen()).unwrap();
let mut y_value = BigUint::from_u128(rng.gen()).unwrap();
// Make `x_value` the larger value so it is used as the dividend.
if y_value > x_value {
(x_value, y_value) = (y_value, x_value);
}
let (expected_div_value, expected_rem_value) = x_value.div_rem(&y_value);
type F = GoldilocksField;
let config = CircuitConfig::standard_recursion_config();
let pw = PartialWitness::new();
let mut builder = CircuitBuilder::<F, 4>::new(config);
let x = builder.constant_biguint(&x_value);
let y = builder.constant_biguint(&y_value);
let (div, rem) = builder.div_rem_biguint(&x, &y);
let expected_div = builder.constant_biguint(&expected_div_value);
let expected_rem = builder.constant_biguint(&expected_rem_value);
builder.connect_biguint(&div, &expected_div);
builder.connect_biguint(&rem, &expected_rem);
let data = builder.build();
let proof = data.prove(pw).unwrap();
verify(proof, &data.verifier_only, &data.common)
}
}

368
src/gadgets/curve.rs Normal file
View File

@ -0,0 +1,368 @@
use crate::curve::curve_types::{AffinePoint, Curve, CurveScalar};
use crate::field::extension_field::Extendable;
use crate::field::field_types::{Field, RichField};
use crate::gadgets::nonnative::NonNativeTarget;
use crate::plonk::circuit_builder::CircuitBuilder;
/// A Target representing an affine point on the curve `C`. We use incomplete arithmetic for efficiency,
/// so we assume these points are not zero.
#[derive(Clone, Debug)]
pub struct AffinePointTarget<C: Curve> {
/// The x-coordinate, as a non-native element of the curve's base field.
pub x: NonNativeTarget<C::BaseField>,
/// The y-coordinate, as a non-native element of the curve's base field.
pub y: NonNativeTarget<C::BaseField>,
}
impl<C: Curve> AffinePointTarget<C> {
    /// Returns clones of the two coordinates, in the order `[x, y]`.
    pub fn to_vec(&self) -> Vec<NonNativeTarget<C::BaseField>> {
        let Self { x, y } = self;
        vec![x.clone(), y.clone()]
    }
}
impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
/// Creates a constant target for `point`.
///
/// The zero point cannot be represented by `AffinePointTarget`; passing it trips a debug
/// assertion.
pub fn constant_affine_point<C: Curve>(
    &mut self,
    point: AffinePoint<C>,
) -> AffinePointTarget<C> {
    debug_assert!(!point.zero);
    let x = self.constant_nonnative(point.x);
    let y = self.constant_nonnative(point.y);
    AffinePointTarget { x, y }
}
pub fn connect_affine_point<C: Curve>(
&mut self,
lhs: &AffinePointTarget<C>,
rhs: &AffinePointTarget<C>,
) {
self.connect_nonnative(&lhs.x, &rhs.x);
self.connect_nonnative(&lhs.y, &rhs.y);
}
/// Allocates a point whose two coordinates are fresh virtual non-native targets.
pub fn add_virtual_affine_point_target<C: Curve>(&mut self) -> AffinePointTarget<C> {
    AffinePointTarget {
        x: self.add_virtual_nonnative_target(),
        y: self.add_virtual_nonnative_target(),
    }
}
pub fn curve_assert_valid<C: Curve>(&mut self, p: &AffinePointTarget<C>) {
let a = self.constant_nonnative(C::A);
let b = self.constant_nonnative(C::B);
let y_squared = self.mul_nonnative(&p.y, &p.y);
let x_squared = self.mul_nonnative(&p.x, &p.x);
let x_cubed = self.mul_nonnative(&x_squared, &p.x);
let a_x = self.mul_nonnative(&a, &p.x);
let a_x_plus_b = self.add_nonnative(&a_x, &b);
let rhs = self.add_nonnative(&x_cubed, &a_x_plus_b);
self.connect_nonnative(&y_squared, &rhs);
}
/// Negates a point: the x-coordinate is kept and the y-coordinate is negated.
pub fn curve_neg<C: Curve>(&mut self, p: &AffinePointTarget<C>) -> AffinePointTarget<C> {
    let y = self.neg_nonnative(&p.y);
    let x = p.x.clone();
    AffinePointTarget { x, y }
}
/// Doubles the point `p` using the affine tangent formulas:
/// `lambda = (3*x^2 + A) / (2*y)`, `x3 = lambda^2 - 2*x`, `y3 = lambda * (x - x3) - y`.
///
/// The formula inverts `2*y`, so it is not meaningful for points with `y = 0`
/// (NOTE(review): how `inv_nonnative` behaves on zero is not visible here; confirm).
pub fn curve_double<C: Curve>(&mut self, p: &AffinePointTarget<C>) -> AffinePointTarget<C> {
let AffinePointTarget { x, y } = p;
// Denominator of the tangent slope: 1 / (2*y).
let double_y = self.add_nonnative(y, y);
let inv_double_y = self.inv_nonnative(&double_y);
// Numerator of the tangent slope: 3*x^2 + A, with 3*x^2 built by repeated addition.
let x_squared = self.mul_nonnative(x, x);
let double_x_squared = self.add_nonnative(&x_squared, &x_squared);
let triple_x_squared = self.add_nonnative(&double_x_squared, &x_squared);
let a = self.constant_nonnative(C::A);
let triple_xx_a = self.add_nonnative(&triple_x_squared, &a);
let lambda = self.mul_nonnative(&triple_xx_a, &inv_double_y);
// x3 = lambda^2 - 2*x.
let lambda_squared = self.mul_nonnative(&lambda, &lambda);
let x_double = self.add_nonnative(x, x);
let x3 = self.sub_nonnative(&lambda_squared, &x_double);
// y3 = lambda * (x - x3) - y.
let x_diff = self.sub_nonnative(x, &x3);
let lambda_x_diff = self.mul_nonnative(&lambda, &x_diff);
let y3 = self.sub_nonnative(&lambda_x_diff, y);
AffinePointTarget { x: x3, y: y3 }
}
/// Adds two points, which are assumed to be non-equal.
///
/// With `u = y2 - y1` and `v = x2 - x1`, this computes an unnormalized result
/// `(x3, y3)` with implicit denominator `v^3`, and then multiplies both coordinates by the
/// inverse of `v^3`. Since `v^3` is inverted, the formula is not meaningful when
/// `x1 == x2`.
pub fn curve_add<C: Curve>(
&mut self,
p1: &AffinePointTarget<C>,
p2: &AffinePointTarget<C>,
) -> AffinePointTarget<C> {
let AffinePointTarget { x: x1, y: y1 } = p1;
let AffinePointTarget { x: x2, y: y2 } = p2;
// u = y2 - y1 and v = x2 - x1, plus the powers u^2, v^2, v^3.
let u = self.sub_nonnative(y2, y1);
let uu = self.mul_nonnative(&u, &u);
let v = self.sub_nonnative(x2, x1);
let vv = self.mul_nonnative(&v, &v);
let vvv = self.mul_nonnative(&v, &vv);
// r = v^2 * x1 and a = u^2 - v^3 - 2*r.
let r = self.mul_nonnative(&vv, x1);
let diff = self.sub_nonnative(&uu, &vvv);
let r2 = self.add_nonnative(&r, &r);
let a = self.sub_nonnative(&diff, &r2);
// Unnormalized coordinates: x3 = v * a and y3 = u * (r - a) - v^3 * y1.
let x3 = self.mul_nonnative(&v, &a);
let r_a = self.sub_nonnative(&r, &a);
let y3_first = self.mul_nonnative(&u, &r_a);
let y3_second = self.mul_nonnative(&vvv, y1);
let y3 = self.sub_nonnative(&y3_first, &y3_second);
// Normalize by multiplying both coordinates with 1 / v^3.
let z3_inv = self.inv_nonnative(&vvv);
let x3_norm = self.mul_nonnative(&x3, &z3_inv);
let y3_norm = self.mul_nonnative(&y3, &z3_inv);
AffinePointTarget {
x: x3_norm,
y: y3_norm,
}
}
/// Multiplies the point `p` by the scalar `n` with a double-and-add loop over the bits of `n`.
///
/// Because the zero point is not supported, the accumulator starts at a blinding point
/// `rando` (a scalar multiple of the generator) that is subtracted again at the end.
///
/// NOTE(review): `rando` comes from `C::ScalarField::rand()` while the circuit is being
/// built and is embedded as a constant, so presumably two builds of the same circuit differ
/// in their constants; confirm this is acceptable for callers.
pub fn curve_scalar_mul<C: Curve>(
&mut self,
p: &AffinePointTarget<C>,
n: &NonNativeTarget<C::ScalarField>,
) -> AffinePointTarget<C> {
let one = self.constant_nonnative(C::BaseField::ONE);
// The i-th bit is paired with 2^i * p below, so the bits are presumably ordered least
// significant first (TODO confirm against `split_nonnative_to_bits`).
let bits = self.split_nonnative_to_bits(n);
// Each bit is lifted into the base field so it can be used as a multiplicative selector.
let bits_as_base: Vec<NonNativeTarget<C::BaseField>> =
bits.iter().map(|b| self.bool_to_nonnative(b)).collect();
let rando = (CurveScalar(C::ScalarField::rand()) * C::GENERATOR_PROJECTIVE).to_affine();
let randot = self.constant_affine_point(rando);
// Result starts at `rando`, which is later subtracted, because we don't support arithmetic with the zero point.
let mut result = self.add_virtual_affine_point_target();
self.connect_affine_point(&randot, &result);
let mut two_i_times_p = self.add_virtual_affine_point_target();
self.connect_affine_point(p, &two_i_times_p);
for bit in bits_as_base.iter() {
let not_bit = self.sub_nonnative(&one, bit);
// The sum is always computed; the bit then selects between the sum and the old result:
// new = bit * (result + 2^i * p) + (1 - bit) * result, per coordinate.
let result_plus_2_i_p = self.curve_add(&result, &two_i_times_p);
let new_x_if_bit = self.mul_nonnative(bit, &result_plus_2_i_p.x);
let new_x_if_not_bit = self.mul_nonnative(&not_bit, &result.x);
let new_y_if_bit = self.mul_nonnative(bit, &result_plus_2_i_p.y);
let new_y_if_not_bit = self.mul_nonnative(&not_bit, &result.y);
let new_x = self.add_nonnative(&new_x_if_bit, &new_x_if_not_bit);
let new_y = self.add_nonnative(&new_y_if_bit, &new_y_if_not_bit);
result = AffinePointTarget { x: new_x, y: new_y };
two_i_times_p = self.curve_double(&two_i_times_p);
}
// Subtract off result's initial value of `rando`.
let neg_r = self.curve_neg(&randot);
result = self.curve_add(&result, &neg_r);
result
}
}
// Tests for the curve gadgets on secp256k1. Each test builds a circuit from constant points,
// constrains the gadget outputs against natively computed points, then proves and verifies.
#[cfg(test)]
mod tests {
use anyhow::Result;
use crate::curve::curve_types::{AffinePoint, Curve, CurveScalar};
use crate::curve::secp256k1::Secp256K1;
use crate::field::field_types::Field;
use crate::field::goldilocks_field::GoldilocksField;
use crate::field::secp256k1_base::Secp256K1Base;
use crate::field::secp256k1_scalar::Secp256K1Scalar;
use crate::iop::witness::PartialWitness;
use crate::plonk::circuit_builder::CircuitBuilder;
use crate::plonk::circuit_data::CircuitConfig;
use crate::plonk::verifier::verify;
// The generator and its negation must both satisfy the curve equation.
#[test]
fn test_curve_point_is_valid() -> Result<()> {
type F = GoldilocksField;
const D: usize = 4;
let config = CircuitConfig::standard_recursion_config();
let pw = PartialWitness::new();
let mut builder = CircuitBuilder::<F, D>::new(config);
let g = Secp256K1::GENERATOR_AFFINE;
let g_target = builder.constant_affine_point(g);
let neg_g_target = builder.curve_neg(&g_target);
builder.curve_assert_valid(&g_target);
builder.curve_assert_valid(&neg_g_target);
let data = builder.build();
let proof = data.prove(pw).unwrap();
verify(proof, &data.verifier_only, &data.common)
}
// A point with a shifted y-coordinate is off the curve, so the test is expected to panic.
#[test]
#[should_panic]
fn test_curve_point_is_not_valid() {
type F = GoldilocksField;
const D: usize = 4;
let config = CircuitConfig::standard_recursion_config();
let pw = PartialWitness::new();
let mut builder = CircuitBuilder::<F, D>::new(config);
let g = Secp256K1::GENERATOR_AFFINE;
// Same x as the generator, but y + 1.
let not_g = AffinePoint::<Secp256K1> {
x: g.x,
y: g.y + Secp256K1Base::ONE,
zero: g.zero,
};
let not_g_target = builder.constant_affine_point(not_g);
builder.curve_assert_valid(&not_g_target);
let data = builder.build();
let proof = data.prove(pw).unwrap();
verify(proof, &data.verifier_only, &data.common).unwrap();
}
// Doubling G and -G in the circuit must match the natively computed doubles.
#[test]
fn test_curve_double() -> Result<()> {
type F = GoldilocksField;
const D: usize = 4;
let config = CircuitConfig::standard_recursion_config();
let pw = PartialWitness::new();
let mut builder = CircuitBuilder::<F, D>::new(config);
let g = Secp256K1::GENERATOR_AFFINE;
let g_target = builder.constant_affine_point(g);
let neg_g_target = builder.curve_neg(&g_target);
let double_g = g.double();
let double_g_expected = builder.constant_affine_point(double_g);
builder.curve_assert_valid(&double_g_expected);
let double_neg_g = (-g).double();
let double_neg_g_expected = builder.constant_affine_point(double_neg_g);
builder.curve_assert_valid(&double_neg_g_expected);
let double_g_actual = builder.curve_double(&g_target);
let double_neg_g_actual = builder.curve_double(&neg_g_target);
builder.curve_assert_valid(&double_g_actual);
builder.curve_assert_valid(&double_neg_g_actual);
builder.connect_affine_point(&double_g_expected, &double_g_actual);
builder.connect_affine_point(&double_neg_g_expected, &double_neg_g_actual);
let data = builder.build();
let proof = data.prove(pw).unwrap();
verify(proof, &data.verifier_only, &data.common)
}
// G + 2G in the circuit must match the natively computed sum.
#[test]
fn test_curve_add() -> Result<()> {
type F = GoldilocksField;
const D: usize = 4;
let config = CircuitConfig::standard_recursion_config();
let pw = PartialWitness::new();
let mut builder = CircuitBuilder::<F, D>::new(config);
let g = Secp256K1::GENERATOR_AFFINE;
let double_g = g.double();
let g_plus_2g = (g + double_g).to_affine();
let g_plus_2g_expected = builder.constant_affine_point(g_plus_2g);
builder.curve_assert_valid(&g_plus_2g_expected);
let g_target = builder.constant_affine_point(g);
let double_g_target = builder.curve_double(&g_target);
let g_plus_2g_actual = builder.curve_add(&g_target, &double_g_target);
builder.curve_assert_valid(&g_plus_2g_actual);
builder.connect_affine_point(&g_plus_2g_expected, &g_plus_2g_actual);
let data = builder.build();
let proof = data.prove(pw).unwrap();
verify(proof, &data.verifier_only, &data.common)
}
// 5 * G via `curve_scalar_mul` must match the native result. Marked `#[ignore]`, so it only
// runs when ignored tests are requested.
#[test]
#[ignore]
fn test_curve_mul() -> Result<()> {
type F = GoldilocksField;
const D: usize = 4;
// Same as the standard recursion config, except for the number of routed wires.
let config = CircuitConfig {
num_routed_wires: 33,
..CircuitConfig::standard_recursion_config()
};
let pw = PartialWitness::new();
let mut builder = CircuitBuilder::<F, D>::new(config);
let g = Secp256K1::GENERATOR_AFFINE;
let five = Secp256K1Scalar::from_canonical_usize(5);
let five_scalar = CurveScalar::<Secp256K1>(five);
let five_g = (five_scalar * g.to_projective()).to_affine();
let five_g_expected = builder.constant_affine_point(five_g);
builder.curve_assert_valid(&five_g_expected);
let g_target = builder.constant_affine_point(g);
let five_target = builder.constant_nonnative(five);
let five_g_actual = builder.curve_scalar_mul(&g_target, &five_target);
builder.curve_assert_valid(&five_g_actual);
builder.connect_affine_point(&five_g_expected, &five_g_actual);
let data = builder.build();
let proof = data.prove(pw).unwrap();
verify(proof, &data.verifier_only, &data.common)
}
// For a random point, doubling and multiplying by the scalar 2 must agree. Marked
// `#[ignore]`, so it only runs when ignored tests are requested.
#[test]
#[ignore]
fn test_curve_random() -> Result<()> {
type F = GoldilocksField;
const D: usize = 4;
// Same as the standard recursion config, except for the number of routed wires.
let config = CircuitConfig {
num_routed_wires: 33,
..CircuitConfig::standard_recursion_config()
};
let pw = PartialWitness::new();
let mut builder = CircuitBuilder::<F, D>::new(config);
let rando =
(CurveScalar(Secp256K1Scalar::rand()) * Secp256K1::GENERATOR_PROJECTIVE).to_affine();
let randot = builder.constant_affine_point(rando);
let two_target = builder.constant_nonnative(Secp256K1Scalar::TWO);
let randot_doubled = builder.curve_double(&randot);
let randot_times_two = builder.curve_scalar_mul(&randot, &two_target);
builder.connect_affine_point(&randot_doubled, &randot_times_two);
let data = builder.build();
let proof = data.prove(pw).unwrap();
verify(proof, &data.verifier_only, &data.common)
}
}

View File

@ -1,22 +1,94 @@
use std::ops::Range;
use crate::field::extension_field::target::ExtensionTarget;
use crate::field::extension_field::Extendable;
use crate::field::field_types::RichField;
use crate::gates::gate::Gate;
use crate::gates::interpolation::InterpolationGate;
use crate::iop::target::Target;
use crate::plonk::circuit_builder::CircuitBuilder;
impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
/// Trait for gates which interpolate a polynomial, whose points are a (base field) coset of the multiplicative subgroup
/// with the given size, and whose values are extension field elements, given by input wires.
/// Outputs the evaluation of the interpolant at a given (extension field) evaluation point.
///
/// The default methods lay the wires out consecutively: the coset shift, then `D` wires per
/// interpolant value, then the evaluation point, the evaluation value, and finally `D` wires
/// per coefficient.
pub(crate) trait InterpolationGate<F: Extendable<D>, const D: usize>:
Gate<F, D> + Copy
{
/// Creates the gate for the given `subgroup_bits` (presumably the log2 of the subgroup
/// size; confirm against the implementors).
fn new(subgroup_bits: usize) -> Self;
/// Number of interpolation points; one value slot and one coefficient slot are laid out
/// per point.
fn num_points(&self) -> usize;
/// Wire index of the coset shift.
fn wire_shift(&self) -> usize {
0
}
/// Index of the first value wire; it directly follows the coset shift wire.
fn start_values(&self) -> usize {
1
}
/// Wire indices of the `i`th interpolant value.
fn wires_value(&self, i: usize) -> Range<usize> {
debug_assert!(i < self.num_points());
let start = self.start_values() + i * D;
start..start + D
}
/// Index of the first evaluation-point wire; it follows the `num_points() * D` value wires.
fn start_evaluation_point(&self) -> usize {
self.start_values() + self.num_points() * D
}
/// Wire indices of the point to evaluate the interpolant at.
fn wires_evaluation_point(&self) -> Range<usize> {
let start = self.start_evaluation_point();
start..start + D
}
/// Index of the first evaluation-value wire; it follows the `D` evaluation-point wires.
fn start_evaluation_value(&self) -> usize {
self.start_evaluation_point() + D
}
/// Wire indices of the interpolated value.
fn wires_evaluation_value(&self) -> Range<usize> {
let start = self.start_evaluation_value();
start..start + D
}
/// Index of the first coefficient wire; it follows the `D` evaluation-value wires.
fn start_coeffs(&self) -> usize {
self.start_evaluation_value() + D
}
/// The number of routed wires required in the typical usage of this gate, where the points to
/// interpolate, the evaluation point, and the corresponding value are all routed.
fn num_routed_wires(&self) -> usize {
self.start_coeffs()
}
/// Wire indices of the interpolant's `i`th coefficient.
fn wires_coeff(&self, i: usize) -> Range<usize> {
debug_assert!(i < self.num_points());
let start = self.start_coeffs() + i * D;
start..start + D
}
/// Index one past the last coefficient wire.
fn end_coeffs(&self) -> usize {
self.start_coeffs() + D * self.num_points()
}
}
impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
/// Interpolates a polynomial, whose points are a coset of the multiplicative subgroup with the
/// given size, and whose values are given. Returns the evaluation of the interpolant at
/// `evaluation_point`.
pub fn interpolate_coset(
pub(crate) fn interpolate_coset<G: InterpolationGate<F, D>>(
&mut self,
subgroup_bits: usize,
coset_shift: Target,
values: &[ExtensionTarget<D>],
evaluation_point: ExtensionTarget<D>,
) -> ExtensionTarget<D> {
let gate = InterpolationGate::new(subgroup_bits);
let gate_index = self.add_gate(gate.clone(), vec![]);
let gate = G::new(subgroup_bits);
let gate_index = self.add_gate(gate, vec![]);
self.connect(coset_shift, Target::wire(gate_index, gate.wire_shift()));
for (i, &v) in values.iter().enumerate() {
self.connect_extension(
@ -37,6 +109,7 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
mod tests {
use anyhow::Result;
use crate::field::extension_field::quadratic::QuadraticExtension;
use crate::field::extension_field::FieldExtension;
use crate::field::field_types::Field;
use crate::field::interpolation::interpolant;
@ -83,9 +156,21 @@ mod tests {
let zt = builder.constant_extension(z);
let eval = builder.interpolate_coset(subgroup_bits, coset_shift_target, &value_targets, zt);
let eval_hd = builder.interpolate_coset::<HighDegreeInterpolationGate<F, D>>(
subgroup_bits,
coset_shift_target,
&value_targets,
zt,
);
let eval_ld = builder.interpolate_coset::<LowDegreeInterpolationGate<F, D>>(
subgroup_bits,
coset_shift_target,
&value_targets,
zt,
);
let true_eval_target = builder.constant_extension(true_eval);
builder.connect_extension(eval, true_eval_target);
builder.connect_extension(eval_hd, true_eval_target);
builder.connect_extension(eval_ld, true_eval_target);
let data = builder.build::<C>();
let proof = data.prove(pw)?;

View File

@ -1,8 +1,13 @@
pub mod arithmetic;
pub mod arithmetic_extension;
pub mod arithmetic_u32;
pub mod biguint;
pub mod curve;
pub mod hash;
pub mod insert;
pub mod interpolation;
pub mod multiple_comparison;
pub mod nonnative;
pub mod permutation;
pub mod polynomial;
pub mod random_access;

View File

@ -0,0 +1,138 @@
use super::arithmetic_u32::U32Target;
use crate::field::extension_field::Extendable;
use crate::field::field_types::RichField;
use crate::gates::comparison::ComparisonGate;
use crate::iop::target::{BoolTarget, Target};
use crate::plonk::circuit_builder::CircuitBuilder;
use crate::util::ceil_div_usize;
impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
/// Returns true if a is less than or equal to b, considered as base-`2^num_bits` limbs of a large value.
/// This range-checks its inputs.
///
/// The limbs are processed from index 0 upwards and later limbs override the verdict of
/// earlier ones, so index 0 is treated as the least significant limb.
///
/// # Panics
/// Panics if `a` and `b` have different lengths.
pub fn list_le(&mut self, a: Vec<Target>, b: Vec<Target>, num_bits: usize) -> BoolTarget {
assert_eq!(
a.len(),
b.len(),
"Comparison must be between same number of inputs and outputs"
);
let n = a.len();
// Each comparison gate works on 2-bit chunks of its inputs.
let chunk_bits = 2;
let num_chunks = ceil_div_usize(num_bits, chunk_bits);
let one = self.one();
// Starts as "true": with no limbs processed yet, the two values are equal.
let mut result = one;
for i in 0..n {
// First gate: a[i] <= b[i].
let a_le_b_gate = ComparisonGate::new(num_bits, num_chunks);
let a_le_b_gate_index = self.add_gate(a_le_b_gate.clone(), vec![]);
self.connect(
Target::wire(a_le_b_gate_index, a_le_b_gate.wire_first_input()),
a[i],
);
self.connect(
Target::wire(a_le_b_gate_index, a_le_b_gate.wire_second_input()),
b[i],
);
let a_le_b_result = Target::wire(a_le_b_gate_index, a_le_b_gate.wire_result_bool());
// Second gate: b[i] <= a[i].
let b_le_a_gate = ComparisonGate::new(num_bits, num_chunks);
let b_le_a_gate_index = self.add_gate(b_le_a_gate.clone(), vec![]);
self.connect(
Target::wire(b_le_a_gate_index, b_le_a_gate.wire_first_input()),
b[i],
);
self.connect(
Target::wire(b_le_a_gate_index, b_le_a_gate.wire_second_input()),
a[i],
);
let b_le_a_result = Target::wire(b_le_a_gate_index, b_le_a_gate.wire_result_bool());
// Equal iff both comparisons hold; strictly less iff `b[i] <= a[i]` fails.
let these_limbs_equal = self.mul(a_le_b_result, b_le_a_result);
let these_limbs_less_than = self.sub(one, b_le_a_result);
// If this limb is equal, keep the verdict of the lower limbs; otherwise this limb decides.
result = self.mul_add(these_limbs_equal, result, these_limbs_less_than);
}
// `result` being boolean is an invariant, maintained because its new value is always
// `x * result + y`, where `x` and `y` are booleans that are not simultaneously true.
BoolTarget::new_unsafe(result)
}
/// Helper function for comparing, specifically, lists of `U32Target`s.
pub fn list_le_u32(&mut self, a: Vec<U32Target>, b: Vec<U32Target>) -> BoolTarget {
let a_targets = a.iter().map(|&t| t.0).collect();
let b_targets = b.iter().map(|&t| t.0).collect();
self.list_le(a_targets, b_targets, 32)
}
}
// Tests for the limb-list comparison gadget.
#[cfg(test)]
mod tests {
use anyhow::Result;
use num::BigUint;
use rand::Rng;
use crate::field::field_types::Field;
use crate::field::goldilocks_field::GoldilocksField;
use crate::iop::witness::PartialWitness;
use crate::plonk::circuit_builder::CircuitBuilder;
use crate::plonk::circuit_data::CircuitConfig;
use crate::plonk::verifier::verify;
// Builds two random lists of `size` limbs, each limb below `2^num_bits`, and checks that
// `list_le` agrees with a native `BigUint` comparison of the same data.
fn test_list_le(size: usize, num_bits: usize) -> Result<()> {
type F = GoldilocksField;
let config = CircuitConfig::standard_recursion_config();
let pw = PartialWitness::new();
let mut builder = CircuitBuilder::<F, 4>::new(config);
let mut rng = rand::thread_rng();
let lst1: Vec<u64> = (0..size)
.map(|_| rng.gen_range(0..(1 << num_bits)))
.collect();
let lst2: Vec<u64> = (0..size)
.map(|_| rng.gen_range(0..(1 << num_bits)))
.collect();
// Reference values: every u64 limb is split into its low and high u32 halves (low half
// first) to form the digits of a `BigUint`.
let a_biguint = BigUint::from_slice(
&lst1
.iter()
.flat_map(|&x| [x as u32, (x >> 32) as u32])
.collect::<Vec<_>>(),
);
let b_biguint = BigUint::from_slice(
&lst2
.iter()
.flat_map(|&x| [x as u32, (x >> 32) as u32])
.collect::<Vec<_>>(),
);
let a = lst1
.iter()
.map(|&x| builder.constant(F::from_canonical_u64(x)))
.collect();
let b = lst2
.iter()
.map(|&x| builder.constant(F::from_canonical_u64(x)))
.collect();
let result = builder.list_le(a, b, num_bits);
let expected_result = builder.constant_bool(a_biguint <= b_biguint);
builder.connect(result.target, expected_result.target);
let data = builder.build();
let proof = data.prove(pw).unwrap();
verify(proof, &data.verifier_only, &data.common)
}
// Runs the check above over several list sizes and limb widths.
#[test]
fn test_multiple_comparison() -> Result<()> {
for size in [1, 3, 6] {
for num_bits in [20, 32, 40, 44] {
test_list_le(size, num_bits).unwrap();
}
}
Ok(())
}
}

342
src/gadgets/nonnative.rs Normal file
View File

@ -0,0 +1,342 @@
use std::marker::PhantomData;
use num::{BigUint, Zero};
use crate::field::field_types::RichField;
use crate::field::{extension_field::Extendable, field_types::Field};
use crate::gadgets::arithmetic_u32::U32Target;
use crate::gadgets::biguint::BigUintTarget;
use crate::iop::generator::{GeneratedValues, SimpleGenerator};
use crate::iop::target::{BoolTarget, Target};
use crate::iop::witness::{PartitionWitness, Witness};
use crate::plonk::circuit_builder::CircuitBuilder;
use crate::util::ceil_div_usize;
/// An element of the field `FF`, represented inside the circuit by a `BigUintTarget`.
#[derive(Clone, Debug)]
pub struct NonNativeTarget<FF: Field> {
// The big-integer limbs holding the element's value.
pub(crate) value: BigUintTarget,
// Ties the otherwise unused field type `FF` to the struct.
_phantom: PhantomData<FF>,
}
impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
    /// Number of 32-bit limbs used for a full-width element of `FF`, i.e. `ceil(FF::BITS / 32)`.
    fn num_nonnative_limbs<FF: Field>() -> usize {
        ceil_div_usize(FF::BITS, 32)
    }

    /// Wraps a copy of `x` as a `NonNativeTarget<FF>`.
    ///
    /// No gates or constraints are added, so this call does not reduce the value modulo the
    /// order of `FF`.
    pub fn biguint_to_nonnative<FF: Field>(&mut self, x: &BigUintTarget) -> NonNativeTarget<FF> {
        NonNativeTarget {
            value: x.clone(),
            _phantom: PhantomData,
        }
    }

    /// Returns a copy of the `BigUintTarget` underlying `x`.
    pub fn nonnative_to_biguint<FF: Field>(&mut self, x: &NonNativeTarget<FF>) -> BigUintTarget {
        x.value.clone()
    }

    /// Creates a constant target for the field element `x`, built from `x.to_biguint()`.
    pub fn constant_nonnative<FF: Field>(&mut self, x: FF) -> NonNativeTarget<FF> {
        let x_biguint = self.constant_biguint(&x.to_biguint());
        self.biguint_to_nonnative(&x_biguint)
    }

    /// Assert that two `NonNativeTarget`s, both assumed to be in reduced form, are equal.
    pub fn connect_nonnative<FF: Field>(
        &mut self,
        lhs: &NonNativeTarget<FF>,
        rhs: &NonNativeTarget<FF>,
    ) {
        self.connect_biguint(&lhs.value, &rhs.value);
    }

    /// Adds a fresh virtual target wide enough (in 32-bit limbs) for any element of `FF`.
    pub fn add_virtual_nonnative_target<FF: Field>(&mut self) -> NonNativeTarget<FF> {
        let num_limbs = Self::num_nonnative_limbs::<FF>();
        let value = self.add_virtual_biguint_target(num_limbs);

        NonNativeTarget {
            value,
            _phantom: PhantomData,
        }
    }

    /// Add two `NonNativeTarget`s and reduce the sum modulo the order of `FF`.
    pub fn add_nonnative<FF: Field>(
        &mut self,
        a: &NonNativeTarget<FF>,
        b: &NonNativeTarget<FF>,
    ) -> NonNativeTarget<FF> {
        let result = self.add_biguint(&a.value, &b.value);

        // TODO: reduce add result with only one conditional subtraction
        self.reduce(&result)
    }

    /// Subtract two `NonNativeTarget`s and reduce the difference modulo the order of `FF`.
    ///
    /// The difference is computed as `(order + a) - b` on the underlying big integers;
    /// presumably the order is added first so the big-integer subtraction does not go below
    /// zero when `b` is in reduced form.
    pub fn sub_nonnative<FF: Field>(
        &mut self,
        a: &NonNativeTarget<FF>,
        b: &NonNativeTarget<FF>,
    ) -> NonNativeTarget<FF> {
        let order = self.constant_biguint(&FF::order());
        let a_plus_order = self.add_biguint(&order, &a.value);
        let result = self.sub_biguint(&a_plus_order, &b.value);

        // TODO: reduce sub result with only one conditional addition?
        self.reduce(&result)
    }

    /// Multiply two `NonNativeTarget`s and reduce the product modulo the order of `FF`.
    pub fn mul_nonnative<FF: Field>(
        &mut self,
        a: &NonNativeTarget<FF>,
        b: &NonNativeTarget<FF>,
    ) -> NonNativeTarget<FF> {
        let result = self.mul_biguint(&a.value, &b.value);

        self.reduce(&result)
    }

    /// Negates `x` by computing `0 - x` with `sub_nonnative`.
    pub fn neg_nonnative<FF: Field>(&mut self, x: &NonNativeTarget<FF>) -> NonNativeTarget<FF> {
        let zero_target = self.constant_biguint(&BigUint::zero());
        let zero_ff = self.biguint_to_nonnative(&zero_target);

        self.sub_nonnative(&zero_ff, x)
    }

    /// Returns a target constrained to be a multiplicative inverse of `x` in `FF`.
    ///
    /// A virtual target is allocated for the inverse, a `NonNativeInverseGenerator` is
    /// registered to fill it in from `x`, and the reduced product `x * inv` is connected to the
    /// constant one.
    ///
    /// NOTE(review): no range check or reduction of the returned target is visible here, so it
    /// may not be in reduced form; confirm whether callers rely on that.
    pub fn inv_nonnative<FF: Field>(&mut self, x: &NonNativeTarget<FF>) -> NonNativeTarget<FF> {
        // `x` can have fewer limbs than a full element of `FF` (e.g. a target produced by
        // `bool_to_nonnative` has a single limb), while its inverse generally needs the full
        // width. Never allocate fewer limbs than a full-width element.
        let num_limbs = x.value.num_limbs().max(Self::num_nonnative_limbs::<FF>());
        let inv_biguint = self.add_virtual_biguint_target(num_limbs);
        let inv = NonNativeTarget::<FF> {
            value: inv_biguint,
            _phantom: PhantomData,
        };

        self.add_simple_generator(NonNativeInverseGenerator::<F, D, FF> {
            x: x.clone(),
            inv: inv.clone(),
            _phantom: PhantomData,
        });

        let product = self.mul_nonnative(x, &inv);
        let one = self.constant_nonnative(FF::ONE);
        self.connect_nonnative(&product, &one);

        inv
    }

    /// Applies `div_rem_biguint` to the big integers underlying `x` and `y` and wraps the two
    /// outputs as `(quotient, remainder)` nonnative targets.
    ///
    /// NOTE(review): this looks like integer division with remainder on the representations
    /// rather than division in `FF`; confirm against `div_rem_biguint`.
    pub fn div_rem_nonnative<FF: Field>(
        &mut self,
        x: &NonNativeTarget<FF>,
        y: &NonNativeTarget<FF>,
    ) -> (NonNativeTarget<FF>, NonNativeTarget<FF>) {
        let x_biguint = self.nonnative_to_biguint(x);
        let y_biguint = self.nonnative_to_biguint(y);

        let (div_biguint, rem_biguint) = self.div_rem_biguint(&x_biguint, &y_biguint);
        let div = self.biguint_to_nonnative(&div_biguint);
        let rem = self.biguint_to_nonnative(&rem_biguint);
        (div, rem)
    }

    /// Returns `x % |FF|` as a `NonNativeTarget`.
    fn reduce<FF: Field>(&mut self, x: &BigUintTarget) -> NonNativeTarget<FF> {
        let modulus = FF::order();
        let order_target = self.constant_biguint(&modulus);
        let value = self.rem_biguint(x, &order_target);

        NonNativeTarget {
            value,
            _phantom: PhantomData,
        }
    }

    /// Reduces the big integer underlying `x` modulo the order of `FF`.
    // Nothing in this file calls this helper yet, hence the lint exemption.
    #[allow(dead_code)]
    fn reduce_nonnative<FF: Field>(&mut self, x: &NonNativeTarget<FF>) -> NonNativeTarget<FF> {
        let x_biguint = self.nonnative_to_biguint(x);
        self.reduce(&x_biguint)
    }

    /// Builds a single-limb nonnative target whose only limb is the target of `b`.
    pub fn bool_to_nonnative<FF: Field>(&mut self, b: &BoolTarget) -> NonNativeTarget<FF> {
        let limbs = vec![U32Target(b.target)];
        let value = BigUintTarget { limbs };

        NonNativeTarget {
            value,
            _phantom: PhantomData,
        }
    }

    /// Split a nonnative field element to bits.
    ///
    /// Each limb of `x`, in limb order, is split into 32 base-2 digits with `split_le_base`,
    /// and the digits are concatenated; the result has `32 * num_limbs` entries.
    pub fn split_nonnative_to_bits<FF: Field>(
        &mut self,
        x: &NonNativeTarget<FF>,
    ) -> Vec<BoolTarget> {
        let num_limbs = x.value.num_limbs();
        let mut result = Vec::with_capacity(num_limbs * 32);

        for i in 0..num_limbs {
            let limb = x.value.get_limb(i);
            let bit_targets = self.split_le_base::<2>(limb.0, 32);
            // NOTE(review): `new_unsafe` skips the boolean check; presumably the base-2 split
            // already forces each digit into `{0, 1}`. Confirm against `split_le_base`.
            result.extend(bit_targets.iter().map(|&t| BoolTarget::new_unsafe(t)));
        }

        result
    }
}
/// Witness generator that fills in `inv` with the result of calling `inverse()` on the witness
/// value of `x`.
#[derive(Debug)]
struct NonNativeInverseGenerator<F: RichField + Extendable<D>, const D: usize, FF: Field> {
    /// The input target whose value is read from the witness.
    x: NonNativeTarget<FF>,
    /// The output target that receives the computed inverse.
    inv: NonNativeTarget<FF>,
    /// Zero-sized marker for the type parameter `F`, which no field stores.
    _phantom: PhantomData<F>,
}
impl<F: RichField + Extendable<D>, const D: usize, FF: Field> SimpleGenerator<F>
    for NonNativeInverseGenerator<F, D, FF>
{
    /// Lists the target of every limb of `x` as a dependency of this generator.
    fn dependencies(&self) -> Vec<Target> {
        let limbs = &self.x.value.limbs;
        limbs.iter().map(|limb| limb.0).collect()
    }

    /// Reads the value of `x` from the witness and writes `x.inverse()` into `inv`.
    fn run_once(&self, witness: &PartitionWitness<F>, out_buffer: &mut GeneratedValues<F>) {
        let x_value = witness.get_nonnative_target(self.x.clone());
        let inverse = x_value.inverse();
        out_buffer.set_nonnative_target(self.inv.clone(), inverse);
    }
}
#[cfg(test)]
mod tests {
    use anyhow::Result;

    use crate::field::field_types::Field;
    use crate::field::goldilocks_field::GoldilocksField;
    use crate::field::secp256k1_base::Secp256K1Base;
    use crate::iop::witness::PartialWitness;
    use crate::plonk::circuit_builder::CircuitBuilder;
    use crate::plonk::circuit_data::CircuitConfig;
    use crate::plonk::verifier::verify;

    /// Native field of the circuits built in these tests.
    type F = GoldilocksField;
    /// Field whose arithmetic is emulated inside the circuit.
    type FF = Secp256K1Base;

    /// `add_nonnative` on two random constants matches the sum computed outside the circuit.
    #[test]
    fn test_nonnative_add() -> Result<()> {
        let lhs_value = FF::rand();
        let rhs_value = FF::rand();
        let expected_value = lhs_value + rhs_value;

        let mut builder = CircuitBuilder::<F, 4>::new(CircuitConfig::standard_recursion_config());
        let lhs = builder.constant_nonnative(lhs_value);
        let rhs = builder.constant_nonnative(rhs_value);
        let actual = builder.add_nonnative(&lhs, &rhs);
        let expected = builder.constant_nonnative(expected_value);
        builder.connect_nonnative(&actual, &expected);

        let circuit = builder.build();
        let proof = circuit.prove(PartialWitness::new()).unwrap();
        verify(proof, &circuit.verifier_only, &circuit.common)
    }

    /// `sub_nonnative` on random constants matches the difference computed outside the circuit.
    #[test]
    fn test_nonnative_sub() -> Result<()> {
        let minuend_value = FF::rand();
        // Keep drawing until the subtrahend is no larger than the minuend as an integer.
        let subtrahend_value = loop {
            let candidate = FF::rand();
            if candidate.to_biguint() <= minuend_value.to_biguint() {
                break candidate;
            }
        };
        let expected_value = minuend_value - subtrahend_value;

        let mut builder = CircuitBuilder::<F, 4>::new(CircuitConfig::standard_recursion_config());
        let minuend = builder.constant_nonnative(minuend_value);
        let subtrahend = builder.constant_nonnative(subtrahend_value);
        let actual = builder.sub_nonnative(&minuend, &subtrahend);
        let expected = builder.constant_nonnative(expected_value);
        builder.connect_nonnative(&actual, &expected);

        let circuit = builder.build();
        let proof = circuit.prove(PartialWitness::new()).unwrap();
        verify(proof, &circuit.verifier_only, &circuit.common)
    }

    /// `mul_nonnative` on two random constants matches the product computed outside the circuit.
    #[test]
    fn test_nonnative_mul() -> Result<()> {
        let lhs_value = FF::rand();
        let rhs_value = FF::rand();
        let expected_value = lhs_value * rhs_value;

        let mut builder = CircuitBuilder::<F, 4>::new(CircuitConfig::standard_recursion_config());
        let lhs = builder.constant_nonnative(lhs_value);
        let rhs = builder.constant_nonnative(rhs_value);
        let actual = builder.mul_nonnative(&lhs, &rhs);
        let expected = builder.constant_nonnative(expected_value);
        builder.connect_nonnative(&actual, &expected);

        let circuit = builder.build();
        let proof = circuit.prove(PartialWitness::new()).unwrap();
        verify(proof, &circuit.verifier_only, &circuit.common)
    }

    /// `neg_nonnative` on a random constant matches the negation computed outside the circuit.
    #[test]
    fn test_nonnative_neg() -> Result<()> {
        let input_value = FF::rand();
        let expected_value = -input_value;

        let mut builder = CircuitBuilder::<F, 4>::new(CircuitConfig::standard_recursion_config());
        let input = builder.constant_nonnative(input_value);
        let actual = builder.neg_nonnative(&input);
        let expected = builder.constant_nonnative(expected_value);
        builder.connect_nonnative(&actual, &expected);

        let circuit = builder.build();
        let proof = circuit.prove(PartialWitness::new()).unwrap();
        verify(proof, &circuit.verifier_only, &circuit.common)
    }

    /// `inv_nonnative` on a random constant matches the inverse computed outside the circuit.
    #[test]
    fn test_nonnative_inv() -> Result<()> {
        let input_value = FF::rand();
        let expected_value = input_value.inverse();

        let mut builder = CircuitBuilder::<F, 4>::new(CircuitConfig::standard_recursion_config());
        let input = builder.constant_nonnative(input_value);
        let actual = builder.inv_nonnative(&input);
        let expected = builder.constant_nonnative(expected_value);
        builder.connect_nonnative(&actual, &expected);

        let circuit = builder.build();
        let proof = circuit.prove(PartialWitness::new()).unwrap();
        verify(proof, &circuit.verifier_only, &circuit.common)
    }
}

View File

@ -2,7 +2,6 @@ use std::collections::BTreeMap;
use std::marker::PhantomData;
use crate::field::{extension_field::Extendable, field_types::Field};
use crate::gates::switch::SwitchGate;
use crate::iop::generator::{GeneratedValues, SimpleGenerator};
use crate::iop::target::Target;
use crate::iop::witness::{PartitionWitness, Witness};
@ -34,7 +33,6 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
self.assert_permutation_2x2(a[0].clone(), a[1].clone(), b[0].clone(), b[1].clone())
}
// For larger lists, we recursively use two smaller permutation networks.
//_ => self.assert_permutation_recursive(a, b)
_ => self.assert_permutation_recursive(a, b),
}
}
@ -72,22 +70,7 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
let chunk_size = a1.len();
if self.current_switch_gates.len() < chunk_size {
self.current_switch_gates
.extend(vec![None; chunk_size - self.current_switch_gates.len()]);
}
let (gate, gate_index, mut next_copy) =
match self.current_switch_gates[chunk_size - 1].clone() {
None => {
let gate = SwitchGate::<F, D>::new_from_config(&self.config, chunk_size);
let gate_index = self.add_gate(gate.clone(), vec![]);
(gate, gate_index, 0)
}
Some((gate, idx, next_copy)) => (gate, idx, next_copy),
};
let num_copies = gate.num_copies;
let (gate, gate_index, next_copy) = self.find_switch_gate(chunk_size);
let mut c = Vec::new();
let mut d = Vec::new();
@ -112,13 +95,6 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
let switch = Target::wire(gate_index, gate.wire_switch_bool(next_copy));
next_copy += 1;
if next_copy == num_copies {
self.current_switch_gates[chunk_size - 1] = None;
} else {
self.current_switch_gates[chunk_size - 1] = Some((gate, gate_index, next_copy));
}
(switch, c, d)
}
@ -402,7 +378,7 @@ mod tests {
let pw = PartialWitness::new();
let mut builder = CircuitBuilder::<F, D>::new(config);
let lst: Vec<F> = (0..size * 2).map(|n| F::from_canonical_usize(n)).collect();
let lst: Vec<F> = (0..size * 2).map(F::from_canonical_usize).collect();
let a: Vec<Vec<Target>> = lst[..]
.chunks(2)
.map(|pair| vec![builder.constant(pair[0]), builder.constant(pair[1])])

View File

@ -63,4 +63,21 @@ impl<const D: usize> PolynomialCoeffsExtAlgebraTarget<D> {
}
acc
}
/// Evaluate the polynomial at a point given its powers. The first power is the point itself, not 1.
///
/// `powers` must hold exactly one entry per non-constant coefficient (checked in debug builds).
/// Starting from the constant coefficient, each step adds `power * coefficient` to the
/// running total via `mul_add_ext_algebra`.
pub fn eval_with_powers<F>(
    &self,
    builder: &mut CircuitBuilder<F, D>,
    powers: &[ExtensionAlgebraTarget<D>],
) -> ExtensionAlgebraTarget<D>
where
    F: RichField + Extendable<D>,
{
    debug_assert_eq!(self.0.len(), powers.len() + 1);
    let mut total = self.0[0];
    for (&coeff, &power) in self.0[1..].iter().zip(powers) {
        total = builder.mul_add_ext_algebra(power, coeff, total);
    }
    total
}
}

View File

@ -3,49 +3,20 @@ use crate::field::extension_field::Extendable;
use crate::gates::random_access::RandomAccessGate;
use crate::iop::target::Target;
use crate::plonk::circuit_builder::CircuitBuilder;
use crate::util::log2_strict;
impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
/// Finds the last available random access gate with the given `vec_size` or add one if there aren't any.
/// Returns `(g,i)` such that there is a random access gate with the given `vec_size` at index
/// `g` and the gate's `i`-th random access is available.
fn find_random_access_gate(&mut self, vec_size: usize) -> (usize, usize) {
let (gate, i) = self
.free_random_access
.get(&vec_size)
.copied()
.unwrap_or_else(|| {
let gate = self.add_gate(
RandomAccessGate::new_from_config(&self.config, vec_size),
vec![],
);
(gate, 0)
});
// Update `free_random_access` with new values.
if i < RandomAccessGate::<F, D>::max_num_copies(
self.config.num_routed_wires,
self.config.num_wires,
vec_size,
) - 1
{
self.free_random_access.insert(vec_size, (gate, i + 1));
} else {
self.free_random_access.remove(&vec_size);
}
(gate, i)
}
/// Checks that a `Target` matches a vector at a non-deterministic index.
/// Note: `access_index` is not range-checked.
pub fn random_access(&mut self, access_index: Target, claimed_element: Target, v: Vec<Target>) {
let vec_size = v.len();
let bits = log2_strict(vec_size);
debug_assert!(vec_size > 0);
if vec_size == 1 {
return self.connect(claimed_element, v[0]);
}
let (gate_index, copy) = self.find_random_access_gate(vec_size);
let dummy_gate = RandomAccessGate::<F, D>::new_from_config(&self.config, vec_size);
let (gate_index, copy) = self.find_random_access_gate(bits);
let dummy_gate = RandomAccessGate::<F, D>::new_from_config(&self.config, bits);
v.iter().enumerate().for_each(|(i, &val)| {
self.connect(

View File

@ -3,6 +3,8 @@ use std::marker::PhantomData;
use itertools::izip;
use crate::field::extension_field::Extendable;
use crate::field::field_types::{Field, RichField};
use crate::gates::assert_le::AssertLessThanGate;
use crate::field::field_types::Field;
use crate::gates::comparison::ComparisonGate;
use crate::iop::generator::{GeneratedValues, SimpleGenerator};
@ -40,9 +42,9 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
self.assert_permutation(a_chunks, b_chunks);
}
/// Add a ComparisonGate to assert that `lhs` is less than `rhs`, where their values are at most `bits` bits.
/// Add an AssertLessThanGate to assert that `lhs` is less than `rhs`, where their values are at most `bits` bits.
pub fn assert_le(&mut self, lhs: Target, rhs: Target, bits: usize, num_chunks: usize) {
let gate = ComparisonGate::new(bits, num_chunks);
let gate = AssertLessThanGate::new(bits, num_chunks);
let gate_index = self.add_gate(gate.clone(), vec![]);
self.connect(Target::wire(gate_index, gate.wire_first_input()), lhs);
@ -126,8 +128,7 @@ impl<F: Extendable<D>, const D: usize> SimpleGenerator<F> for MemoryOpSortGenera
fn dependencies(&self) -> Vec<Target> {
self.input_ops
.iter()
.map(|op| vec![op.is_write.target, op.address, op.timestamp, op.value])
.flatten()
.flat_map(|op| vec![op.is_write.target, op.address, op.timestamp, op.value])
.collect()
}
@ -223,7 +224,7 @@ mod tests {
izip!(is_write_vals, address_vals, timestamp_vals, value_vals)
.zip(combined_vals_u64)
.collect::<Vec<_>>();
input_ops_and_keys.sort_by_key(|(_, val)| val.clone());
input_ops_and_keys.sort_by_key(|(_, val)| *val);
let input_ops_sorted: Vec<_> = input_ops_and_keys.iter().map(|(x, _)| x).collect();
let output_ops =

View File

@ -1,5 +1,7 @@
use std::borrow::Borrow;
use itertools::Itertools;
use crate::field::extension_field::Extendable;
use crate::field::field_types::Field;
use crate::gates::base_sum::BaseSumGate;
@ -27,23 +29,25 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
/// Takes an iterator of bits `(b_i)` and returns `sum b_i * 2^i`, i.e.,
/// the number with little-endian bit representation given by `bits`.
pub(crate) fn le_sum(
&mut self,
bits: impl ExactSizeIterator<Item = impl Borrow<BoolTarget>> + Clone,
) -> Target {
pub(crate) fn le_sum(&mut self, bits: impl Iterator<Item = impl Borrow<BoolTarget>>) -> Target {
let bits = bits.map(|b| *b.borrow()).collect_vec();
let num_bits = bits.len();
if num_bits == 0 {
return self.zero();
} else if num_bits == 1 {
let mut bits = bits;
return bits.next().unwrap().borrow().target;
} else if num_bits == 2 {
let two = self.two();
let mut bits = bits;
let b0 = bits.next().unwrap().borrow().target;
let b1 = bits.next().unwrap().borrow().target;
return self.mul_add(two, b1, b0);
}
// Check if it's cheaper to just do this with arithmetic operations.
let arithmetic_ops = num_bits - 1;
if arithmetic_ops <= self.num_base_arithmetic_ops_per_gate() {
let two = self.two();
let mut rev_bits = bits.iter().rev();
let mut sum = rev_bits.next().unwrap().target;
for &bit in rev_bits {
sum = self.mul_add(two, sum, bit.target);
}
return sum;
}
debug_assert!(
BaseSumGate::<2>::START_LIMBS + num_bits <= self.config.num_routed_wires,
"Not enough routed wires."
@ -51,10 +55,10 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
let gate_type = BaseSumGate::<2>::new_from_config::<F>(&self.config);
let gate_index = self.add_gate(gate_type, vec![]);
for (limb, wire) in bits
.clone()
.iter()
.zip(BaseSumGate::<2>::START_LIMBS..BaseSumGate::<2>::START_LIMBS + num_bits)
{
self.connect(limb.borrow().target, Target::wire(gate_index, wire));
self.connect(limb.target, Target::wire(gate_index, wire));
}
for l in gate_type.limbs().skip(num_bits) {
self.assert_zero(Target::wire(gate_index, l));
@ -62,7 +66,7 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
self.add_simple_generator(BaseSumGenerator::<2> {
gate_index,
limbs: bits.map(|l| *l.borrow()).collect(),
limbs: bits,
});
Target::wire(gate_index, BaseSumGate::<2>::WIRE_SUM)
@ -146,14 +150,14 @@ mod tests {
let pw = PartialWitness::new();
let mut builder = CircuitBuilder::<F, D>::new(config);
let n = thread_rng().gen_range(0..(1 << 10));
let n = thread_rng().gen_range(0..(1 << 30));
let x = builder.constant(F::from_canonical_usize(n));
let zero = builder._false();
let one = builder._true();
let y = builder.le_sum(
(0..10)
(0..30)
.scan(n, |acc, _| {
let tmp = *acc % 2;
*acc /= 2;

View File

@ -24,8 +24,7 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
let mut bits = Vec::with_capacity(num_bits);
for &gate in &gates {
let start_limbs = BaseSumGate::<2>::START_LIMBS;
for limb_input in start_limbs..start_limbs + gate_type.num_limbs {
for limb_input in gate_type.limbs() {
// `new_unsafe` is safe here because BaseSumGate::<2> forces it to be in `{0, 1}`.
bits.push(BoolTarget::new_unsafe(Target::wire(gate, limb_input)));
}
@ -35,10 +34,11 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
}
let zero = self.zero();
let base = F::TWO.exp_u64(gate_type.num_limbs as u64);
let mut acc = zero;
for &gate in gates.iter().rev() {
let sum = Target::wire(gate, BaseSumGate::<2>::WIRE_SUM);
acc = self.mul_const_add(F::from_canonical_usize(1 << gate_type.num_limbs), acc, sum);
acc = self.mul_const_add(base, acc, sum);
}
self.connect(acc, integer);
@ -96,11 +96,18 @@ impl<F: RichField> SimpleGenerator<F> for WireSplitGenerator {
for &gate in &self.gates {
let sum = Target::wire(gate, BaseSumGate::<2>::WIRE_SUM);
out_buffer.set_target(
sum,
F::from_canonical_u64(integer_value & ((1 << self.num_limbs) - 1)),
);
integer_value >>= self.num_limbs;
// If num_limbs >= 64, we don't need to truncate since `integer_value` is already
// limited to 64 bits, and trying to do so would cause overflow. Hence the conditional.
let mut truncated_value = integer_value;
if self.num_limbs < 64 {
truncated_value = integer_value & ((1 << self.num_limbs) - 1);
integer_value >>= self.num_limbs;
} else {
integer_value = 0;
};
out_buffer.set_target(sum, F::from_canonical_u64(truncated_value));
}
debug_assert_eq!(

View File

@ -0,0 +1,212 @@
use crate::field::extension_field::target::ExtensionTarget;
use crate::field::extension_field::Extendable;
use crate::field::field_types::RichField;
use crate::gates::gate::Gate;
use crate::iop::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator};
use crate::iop::target::Target;
use crate::iop::witness::{PartitionWitness, Witness};
use crate::plonk::circuit_builder::CircuitBuilder;
use crate::plonk::circuit_data::CircuitConfig;
use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase};
/// A gate which can perform a weighted multiply-add, i.e. `result = c0 x y + c1 z`. If the config
/// supports enough routed wires, it can support several such operations in one gate.
///
/// The weights `c0` and `c1` are the gate's two local constants and are shared by every
/// operation in the gate. Operation `i` uses wires `4 * i` (first multiplicand), `4 * i + 1`
/// (second multiplicand), `4 * i + 2` (addend) and `4 * i + 3` (output).
#[derive(Debug)]
pub struct ArithmeticGate {
    /// Number of arithmetic operations performed by an arithmetic gate.
    pub num_ops: usize,
}
impl ArithmeticGate {
    /// Builds a gate holding as many operations as `config` allows (see [`Self::num_ops`]).
    pub fn new_from_config(config: &CircuitConfig) -> Self {
        let num_ops = Self::num_ops(config);
        Self { num_ops }
    }

    /// Determine the maximum number of operations that can fit in one gate for the given config.
    pub(crate) fn num_ops(config: &CircuitConfig) -> usize {
        // Every operation takes four routed wires: two multiplicands, an addend and an output.
        const WIRES_PER_OP: usize = 4;
        config.num_routed_wires / WIRES_PER_OP
    }

    /// Wire index of the first multiplicand of the `i`-th operation.
    pub fn wire_ith_multiplicand_0(i: usize) -> usize {
        4 * i
    }

    /// Wire index of the second multiplicand of the `i`-th operation.
    pub fn wire_ith_multiplicand_1(i: usize) -> usize {
        4 * i + 1
    }

    /// Wire index of the addend of the `i`-th operation.
    pub fn wire_ith_addend(i: usize) -> usize {
        4 * i + 2
    }

    /// Wire index of the output of the `i`-th operation.
    pub fn wire_ith_output(i: usize) -> usize {
        4 * i + 3
    }
}
impl<F: RichField + Extendable<D>, const D: usize> Gate<F, D> for ArithmeticGate {
    /// Identifier derived from the gate's `Debug` output.
    fn id(&self) -> String {
        format!("{:?}", self)
    }

    /// One constraint per operation: `output - (c0 * x * y + c1 * z)`, over the extension field.
    fn eval_unfiltered(&self, vars: EvaluationVars<F, D>) -> Vec<F::Extension> {
        let const_0 = vars.local_constants[0];
        let const_1 = vars.local_constants[1];

        (0..self.num_ops)
            .map(|op| {
                let multiplicand_0 = vars.local_wires[Self::wire_ith_multiplicand_0(op)];
                let multiplicand_1 = vars.local_wires[Self::wire_ith_multiplicand_1(op)];
                let addend = vars.local_wires[Self::wire_ith_addend(op)];
                let output = vars.local_wires[Self::wire_ith_output(op)];

                output - (multiplicand_0 * multiplicand_1 * const_0 + addend * const_1)
            })
            .collect()
    }

    /// Same constraints as `eval_unfiltered`, evaluated over the base field.
    fn eval_unfiltered_base(&self, vars: EvaluationVarsBase<F>) -> Vec<F> {
        let const_0 = vars.local_constants[0];
        let const_1 = vars.local_constants[1];

        (0..self.num_ops)
            .map(|op| {
                let multiplicand_0 = vars.local_wires[Self::wire_ith_multiplicand_0(op)];
                let multiplicand_1 = vars.local_wires[Self::wire_ith_multiplicand_1(op)];
                let addend = vars.local_wires[Self::wire_ith_addend(op)];
                let output = vars.local_wires[Self::wire_ith_output(op)];

                output - (multiplicand_0 * multiplicand_1 * const_0 + addend * const_1)
            })
            .collect()
    }

    /// Same constraints as `eval_unfiltered`, expressed as extension targets built with
    /// `builder`.
    fn eval_unfiltered_recursively(
        &self,
        builder: &mut CircuitBuilder<F, D>,
        vars: EvaluationTargets<D>,
    ) -> Vec<ExtensionTarget<D>> {
        let const_0 = vars.local_constants[0];
        let const_1 = vars.local_constants[1];

        (0..self.num_ops)
            .map(|op| {
                let multiplicand_0 = vars.local_wires[Self::wire_ith_multiplicand_0(op)];
                let multiplicand_1 = vars.local_wires[Self::wire_ith_multiplicand_1(op)];
                let addend = vars.local_wires[Self::wire_ith_addend(op)];
                let output = vars.local_wires[Self::wire_ith_output(op)];

                // c0 * x * y first, then c1 * z added on top of it.
                let scaled_mul =
                    builder.mul_many_extension(&[const_0, multiplicand_0, multiplicand_1]);
                let computed_output = builder.mul_add_extension(const_1, addend, scaled_mul);

                builder.sub_extension(output, computed_output)
            })
            .collect()
    }

    /// One witness generator per operation, each carrying the gate's two constants.
    fn generators(
        &self,
        gate_index: usize,
        local_constants: &[F],
    ) -> Vec<Box<dyn WitnessGenerator<F>>> {
        (0..self.num_ops)
            .map(|i| -> Box<dyn WitnessGenerator<F>> {
                Box::new(
                    ArithmeticBaseGenerator {
                        gate_index,
                        const_0: local_constants[0],
                        const_1: local_constants[1],
                        i,
                    }
                    .adapter(),
                )
            })
            .collect()
    }

    /// Four wires per operation.
    fn num_wires(&self) -> usize {
        4 * self.num_ops
    }

    /// The two weights `c0` and `c1`.
    fn num_constants(&self) -> usize {
        2
    }

    fn degree(&self) -> usize {
        3
    }

    /// One constraint per operation.
    fn num_constraints(&self) -> usize {
        self.num_ops
    }
}
/// Witness generator for a single operation of an `ArithmeticGate`: computes
/// `const_0 * x * y + const_1 * z` from the operation's input wires and writes it to the
/// operation's output wire.
#[derive(Clone, Debug)]
struct ArithmeticBaseGenerator<F: RichField + Extendable<D>, const D: usize> {
    /// Index of the gate instance whose wires are read and written.
    gate_index: usize,
    /// Weight applied to the product of the two multiplicands.
    const_0: F,
    /// Weight applied to the addend.
    const_1: F,
    /// Index of the operation within the gate.
    i: usize,
}
impl<F: RichField + Extendable<D>, const D: usize> SimpleGenerator<F>
    for ArithmeticBaseGenerator<F, D>
{
    /// The three input wires (both multiplicands and the addend) of operation `i`.
    fn dependencies(&self) -> Vec<Target> {
        let op = self.i;
        vec![
            Target::wire(self.gate_index, ArithmeticGate::wire_ith_multiplicand_0(op)),
            Target::wire(self.gate_index, ArithmeticGate::wire_ith_multiplicand_1(op)),
            Target::wire(self.gate_index, ArithmeticGate::wire_ith_addend(op)),
        ]
    }

    /// Reads the inputs of operation `i` and sets its output wire to
    /// `x * y * const_0 + z * const_1`.
    fn run_once(&self, witness: &PartitionWitness<F>, out_buffer: &mut GeneratedValues<F>) {
        let op = self.i;
        let read_wire =
            |wire: usize| -> F { witness.get_target(Target::wire(self.gate_index, wire)) };

        let x = read_wire(ArithmeticGate::wire_ith_multiplicand_0(op));
        let y = read_wire(ArithmeticGate::wire_ith_multiplicand_1(op));
        let z = read_wire(ArithmeticGate::wire_ith_addend(op));

        let result = x * y * self.const_0 + z * self.const_1;
        let destination = Target::wire(self.gate_index, ArithmeticGate::wire_ith_output(op));
        out_buffer.set_target(destination, result)
    }
}
#[cfg(test)]
mod tests {
    use anyhow::Result;

    use crate::field::goldilocks_field::GoldilocksField;
    use crate::gates::arithmetic_base::ArithmeticGate;
    use crate::gates::gate_testing::{test_eval_fns, test_low_degree};
    use crate::plonk::circuit_data::CircuitConfig;

    /// The gate under test, sized from the standard recursion config.
    fn standard_gate() -> ArithmeticGate {
        let config = CircuitConfig::standard_recursion_config();
        ArithmeticGate::new_from_config(&config)
    }

    #[test]
    fn low_degree() {
        test_low_degree::<GoldilocksField, _, 4>(standard_gate());
    }

    #[test]
    fn eval_fns() -> Result<()> {
        test_eval_fns::<GoldilocksField, _, 4>(standard_gate())
    }
}

View File

@ -11,7 +11,8 @@ use crate::plonk::circuit_builder::CircuitBuilder;
use crate::plonk::circuit_data::CircuitConfig;
use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase};
/// A gate which can a linear combination `c0*x*y+c1*z` twice with the same `x`.
/// A gate which can perform a weighted multiply-add, i.e. `result = c0 x y + c1 z`. If the config
/// supports enough routed wires, it can support several such operations in one gate.
#[derive(Debug)]
pub struct ArithmeticExtensionGate<const D: usize> {
/// Number of arithmetic operations performed by an arithmetic gate.
@ -203,7 +204,7 @@ mod tests {
use anyhow::Result;
use crate::field::goldilocks_field::GoldilocksField;
use crate::gates::arithmetic::ArithmeticExtensionGate;
use crate::gates::arithmetic_extension::ArithmeticExtensionGate;
use crate::gates::gate_testing::{test_eval_fns, test_low_degree};
use crate::plonk::circuit_data::CircuitConfig;
use crate::plonk::config::{GenericConfig, PoseidonGoldilocksConfig};

View File

@ -11,37 +11,49 @@ use crate::iop::target::Target;
use crate::iop::wire::Wire;
use crate::iop::witness::{PartitionWitness, Witness};
use crate::plonk::circuit_builder::CircuitBuilder;
use crate::plonk::circuit_data::CircuitConfig;
use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase};
/// Number of arithmetic operations performed by an arithmetic gate.
pub const NUM_U32_ARITHMETIC_OPS: usize = 3;
/// A gate to perform a basic mul-add on 32-bit values (we assume they are range-checked beforehand).
#[derive(Debug)]
#[derive(Copy, Clone, Debug)]
pub struct U32ArithmeticGate<F: Extendable<D>, const D: usize> {
pub num_ops: usize,
_phantom: PhantomData<F>,
}
impl<F: Extendable<D>, const D: usize> U32ArithmeticGate<F, D> {
pub fn wire_ith_multiplicand_0(i: usize) -> usize {
debug_assert!(i < NUM_U32_ARITHMETIC_OPS);
pub fn new_from_config(config: &CircuitConfig) -> Self {
Self {
num_ops: Self::num_ops(config),
_phantom: PhantomData,
}
}
pub(crate) fn num_ops(config: &CircuitConfig) -> usize {
let wires_per_op = 5 + Self::num_limbs();
let routed_wires_per_op = 5;
(config.num_wires / wires_per_op).min(config.num_routed_wires / routed_wires_per_op)
}
pub fn wire_ith_multiplicand_0(&self, i: usize) -> usize {
debug_assert!(i < self.num_ops);
5 * i
}
pub fn wire_ith_multiplicand_1(i: usize) -> usize {
debug_assert!(i < NUM_U32_ARITHMETIC_OPS);
pub fn wire_ith_multiplicand_1(&self, i: usize) -> usize {
debug_assert!(i < self.num_ops);
5 * i + 1
}
pub fn wire_ith_addend(i: usize) -> usize {
debug_assert!(i < NUM_U32_ARITHMETIC_OPS);
pub fn wire_ith_addend(&self, i: usize) -> usize {
debug_assert!(i < self.num_ops);
5 * i + 2
}
pub fn wire_ith_output_low_half(i: usize) -> usize {
debug_assert!(i < NUM_U32_ARITHMETIC_OPS);
pub fn wire_ith_output_low_half(&self, i: usize) -> usize {
debug_assert!(i < self.num_ops);
5 * i + 3
}
pub fn wire_ith_output_high_half(i: usize) -> usize {
debug_assert!(i < NUM_U32_ARITHMETIC_OPS);
pub fn wire_ith_output_high_half(&self, i: usize) -> usize {
debug_assert!(i < self.num_ops);
5 * i + 4
}
@ -52,10 +64,10 @@ impl<F: Extendable<D>, const D: usize> U32ArithmeticGate<F, D> {
64 / Self::limb_bits()
}
pub fn wire_ith_output_jth_limb(i: usize, j: usize) -> usize {
debug_assert!(i < NUM_U32_ARITHMETIC_OPS);
pub fn wire_ith_output_jth_limb(&self, i: usize, j: usize) -> usize {
debug_assert!(i < self.num_ops);
debug_assert!(j < Self::num_limbs());
5 * NUM_U32_ARITHMETIC_OPS + Self::num_limbs() * i + j
5 * self.num_ops + Self::num_limbs() * i + j
}
}
@ -66,15 +78,15 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for U32ArithmeticGate<F, D> {
fn eval_unfiltered(&self, vars: EvaluationVars<F, D>) -> Vec<F::Extension> {
let mut constraints = Vec::with_capacity(self.num_constraints());
for i in 0..NUM_U32_ARITHMETIC_OPS {
let multiplicand_0 = vars.local_wires[Self::wire_ith_multiplicand_0(i)];
let multiplicand_1 = vars.local_wires[Self::wire_ith_multiplicand_1(i)];
let addend = vars.local_wires[Self::wire_ith_addend(i)];
for i in 0..self.num_ops {
let multiplicand_0 = vars.local_wires[self.wire_ith_multiplicand_0(i)];
let multiplicand_1 = vars.local_wires[self.wire_ith_multiplicand_1(i)];
let addend = vars.local_wires[self.wire_ith_addend(i)];
let computed_output = multiplicand_0 * multiplicand_1 + addend;
let output_low = vars.local_wires[Self::wire_ith_output_low_half(i)];
let output_high = vars.local_wires[Self::wire_ith_output_high_half(i)];
let output_low = vars.local_wires[self.wire_ith_output_low_half(i)];
let output_high = vars.local_wires[self.wire_ith_output_high_half(i)];
let base = F::Extension::from_canonical_u64(1 << 32u64);
let combined_output = output_high * base + output_low;
@ -86,7 +98,7 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for U32ArithmeticGate<F, D> {
let midpoint = Self::num_limbs() / 2;
let base = F::Extension::from_canonical_u64(1u64 << Self::limb_bits());
for j in (0..Self::num_limbs()).rev() {
let this_limb = vars.local_wires[Self::wire_ith_output_jth_limb(i, j)];
let this_limb = vars.local_wires[self.wire_ith_output_jth_limb(i, j)];
let max_limb = 1 << Self::limb_bits();
let product = (0..max_limb)
.map(|x| this_limb - F::Extension::from_canonical_usize(x))
@ -108,15 +120,15 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for U32ArithmeticGate<F, D> {
fn eval_unfiltered_base(&self, vars: EvaluationVarsBase<F>) -> Vec<F> {
let mut constraints = Vec::with_capacity(self.num_constraints());
for i in 0..NUM_U32_ARITHMETIC_OPS {
let multiplicand_0 = vars.local_wires[Self::wire_ith_multiplicand_0(i)];
let multiplicand_1 = vars.local_wires[Self::wire_ith_multiplicand_1(i)];
let addend = vars.local_wires[Self::wire_ith_addend(i)];
for i in 0..self.num_ops {
let multiplicand_0 = vars.local_wires[self.wire_ith_multiplicand_0(i)];
let multiplicand_1 = vars.local_wires[self.wire_ith_multiplicand_1(i)];
let addend = vars.local_wires[self.wire_ith_addend(i)];
let computed_output = multiplicand_0 * multiplicand_1 + addend;
let output_low = vars.local_wires[Self::wire_ith_output_low_half(i)];
let output_high = vars.local_wires[Self::wire_ith_output_high_half(i)];
let output_low = vars.local_wires[self.wire_ith_output_low_half(i)];
let output_high = vars.local_wires[self.wire_ith_output_high_half(i)];
let base = F::from_canonical_u64(1 << 32u64);
let combined_output = output_high * base + output_low;
@ -128,7 +140,7 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for U32ArithmeticGate<F, D> {
let midpoint = Self::num_limbs() / 2;
let base = F::from_canonical_u64(1u64 << Self::limb_bits());
for j in (0..Self::num_limbs()).rev() {
let this_limb = vars.local_wires[Self::wire_ith_output_jth_limb(i, j)];
let this_limb = vars.local_wires[self.wire_ith_output_jth_limb(i, j)];
let max_limb = 1 << Self::limb_bits();
let product = (0..max_limb)
.map(|x| this_limb - F::from_canonical_usize(x))
@ -155,15 +167,15 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for U32ArithmeticGate<F, D> {
) -> Vec<ExtensionTarget<D>> {
let mut constraints = Vec::with_capacity(self.num_constraints());
for i in 0..NUM_U32_ARITHMETIC_OPS {
let multiplicand_0 = vars.local_wires[Self::wire_ith_multiplicand_0(i)];
let multiplicand_1 = vars.local_wires[Self::wire_ith_multiplicand_1(i)];
let addend = vars.local_wires[Self::wire_ith_addend(i)];
for i in 0..self.num_ops {
let multiplicand_0 = vars.local_wires[self.wire_ith_multiplicand_0(i)];
let multiplicand_1 = vars.local_wires[self.wire_ith_multiplicand_1(i)];
let addend = vars.local_wires[self.wire_ith_addend(i)];
let computed_output = builder.mul_add_extension(multiplicand_0, multiplicand_1, addend);
let output_low = vars.local_wires[Self::wire_ith_output_low_half(i)];
let output_high = vars.local_wires[Self::wire_ith_output_high_half(i)];
let output_low = vars.local_wires[self.wire_ith_output_low_half(i)];
let output_high = vars.local_wires[self.wire_ith_output_high_half(i)];
let base: F::Extension = F::from_canonical_u64(1 << 32u64).into();
let base_target = builder.constant_extension(base);
@ -177,7 +189,7 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for U32ArithmeticGate<F, D> {
let base = builder
.constant_extension(F::Extension::from_canonical_u64(1u64 << Self::limb_bits()));
for j in (0..Self::num_limbs()).rev() {
let this_limb = vars.local_wires[Self::wire_ith_output_jth_limb(i, j)];
let this_limb = vars.local_wires[self.wire_ith_output_jth_limb(i, j)];
let max_limb = 1 << Self::limb_bits();
let mut product = builder.one_extension();
@ -210,10 +222,11 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for U32ArithmeticGate<F, D> {
gate_index: usize,
_local_constants: &[F],
) -> Vec<Box<dyn WitnessGenerator<F>>> {
(0..NUM_U32_ARITHMETIC_OPS)
(0..self.num_ops)
.map(|i| {
let g: Box<dyn WitnessGenerator<F>> = Box::new(
U32ArithmeticGenerator {
gate: *self,
gate_index,
i,
_phantom: PhantomData,
@ -226,7 +239,7 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for U32ArithmeticGate<F, D> {
}
fn num_wires(&self) -> usize {
NUM_U32_ARITHMETIC_OPS * (5 + Self::num_limbs())
self.num_ops * (5 + Self::num_limbs())
}
fn num_constants(&self) -> usize {
@ -238,12 +251,13 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for U32ArithmeticGate<F, D> {
}
fn num_constraints(&self) -> usize {
NUM_U32_ARITHMETIC_OPS * (3 + Self::num_limbs())
self.num_ops * (3 + Self::num_limbs())
}
}
#[derive(Clone, Debug)]
struct U32ArithmeticGenerator<F: Extendable<D>, const D: usize> {
gate: U32ArithmeticGate<F, D>,
gate_index: usize,
i: usize,
_phantom: PhantomData<F>,
@ -253,17 +267,11 @@ impl<F: Extendable<D>, const D: usize> SimpleGenerator<F> for U32ArithmeticGener
fn dependencies(&self) -> Vec<Target> {
let local_target = |input| Target::wire(self.gate_index, input);
let mut deps = Vec::with_capacity(3);
deps.push(local_target(
U32ArithmeticGate::<F, D>::wire_ith_multiplicand_0(self.i),
));
deps.push(local_target(
U32ArithmeticGate::<F, D>::wire_ith_multiplicand_1(self.i),
));
deps.push(local_target(U32ArithmeticGate::<F, D>::wire_ith_addend(
self.i,
)));
deps
vec![
local_target(self.gate.wire_ith_multiplicand_0(self.i)),
local_target(self.gate.wire_ith_multiplicand_1(self.i)),
local_target(self.gate.wire_ith_addend(self.i)),
]
}
fn run_once(&self, witness: &PartitionWitness<F>, out_buffer: &mut GeneratedValues<F>) {
@ -274,11 +282,9 @@ impl<F: Extendable<D>, const D: usize> SimpleGenerator<F> for U32ArithmeticGener
let get_local_wire = |input| witness.get_wire(local_wire(input));
let multiplicand_0 =
get_local_wire(U32ArithmeticGate::<F, D>::wire_ith_multiplicand_0(self.i));
let multiplicand_1 =
get_local_wire(U32ArithmeticGate::<F, D>::wire_ith_multiplicand_1(self.i));
let addend = get_local_wire(U32ArithmeticGate::<F, D>::wire_ith_addend(self.i));
let multiplicand_0 = get_local_wire(self.gate.wire_ith_multiplicand_0(self.i));
let multiplicand_1 = get_local_wire(self.gate.wire_ith_multiplicand_1(self.i));
let addend = get_local_wire(self.gate.wire_ith_addend(self.i));
let output = multiplicand_0 * multiplicand_1 + addend;
let mut output_u64 = output.to_canonical_u64();
@ -289,34 +295,25 @@ impl<F: Extendable<D>, const D: usize> SimpleGenerator<F> for U32ArithmeticGener
let output_high = F::from_canonical_u64(output_high_u64);
let output_low = F::from_canonical_u64(output_low_u64);
let output_high_wire =
local_wire(U32ArithmeticGate::<F, D>::wire_ith_output_high_half(self.i));
let output_low_wire =
local_wire(U32ArithmeticGate::<F, D>::wire_ith_output_low_half(self.i));
let output_high_wire = local_wire(self.gate.wire_ith_output_high_half(self.i));
let output_low_wire = local_wire(self.gate.wire_ith_output_low_half(self.i));
out_buffer.set_wire(output_high_wire, output_high);
out_buffer.set_wire(output_low_wire, output_low);
let num_limbs = U32ArithmeticGate::<F, D>::num_limbs();
let limb_base = 1 << U32ArithmeticGate::<F, D>::limb_bits();
let output_limbs_u64: Vec<_> = unfold((), move |_| {
let output_limbs_u64 = unfold((), move |_| {
let ret = output_u64 % limb_base;
output_u64 /= limb_base;
Some(ret)
})
.take(num_limbs)
.collect();
let output_limbs_f: Vec<_> = output_limbs_u64
.iter()
.cloned()
.map(F::from_canonical_u64)
.collect();
.take(num_limbs);
let output_limbs_f = output_limbs_u64.map(F::from_canonical_u64);
for j in 0..num_limbs {
let wire = local_wire(U32ArithmeticGate::<F, D>::wire_ith_output_jth_limb(
self.i, j,
));
out_buffer.set_wire(wire, output_limbs_f[j]);
for (j, output_limb) in output_limbs_f.enumerate() {
let wire = local_wire(self.gate.wire_ith_output_jth_limb(self.i, j));
out_buffer.set_wire(wire, output_limb);
}
}
}
@ -330,7 +327,7 @@ mod tests {
use crate::field::field_types::Field;
use crate::field::goldilocks_field::GoldilocksField;
use crate::gates::arithmetic_u32::{U32ArithmeticGate, NUM_U32_ARITHMETIC_OPS};
use crate::gates::arithmetic_u32::U32ArithmeticGate;
use crate::gates::gate::Gate;
use crate::gates::gate_testing::{test_eval_fns, test_low_degree};
use crate::hash::hash_types::HashOut;
@ -340,16 +337,15 @@ mod tests {
#[test]
fn low_degree() {
test_low_degree::<GoldilocksField, _, 4>(U32ArithmeticGate::<GoldilocksField, 4> {
num_ops: 3,
_phantom: PhantomData,
})
}
#[test]
fn eval_fns() -> Result<()> {
const D: usize = 2;
type C = PoseidonGoldilocksConfig;
type F = <C as GenericConfig<D>>::F;
test_eval_fns::<F, C, _, D>(U32ArithmeticGate::<F, D> {
test_eval_fns::<GoldilocksField, _, 4>(U32ArithmeticGate::<GoldilocksField, 4> {
num_ops: 3,
_phantom: PhantomData,
})
}
@ -360,6 +356,7 @@ mod tests {
type C = PoseidonGoldilocksConfig;
type F = <C as GenericConfig<D>>::F;
type FF = <C as GenericConfig<D>>::FE;
const NUM_U32_ARITHMETIC_OPS: usize = 3;
fn get_wires(
multiplicands_0: Vec<u64>,
@ -387,8 +384,7 @@ mod tests {
output /= limb_base;
}
let mut output_limbs_f: Vec<_> = output_limbs
.iter()
.cloned()
.into_iter()
.map(F::from_canonical_u64)
.collect();
@ -418,6 +414,7 @@ mod tests {
.collect();
let gate = U32ArithmeticGate::<F, D> {
num_ops: NUM_U32_ARITHMETIC_OPS,
_phantom: PhantomData,
};

607
src/gates/assert_le.rs Normal file
View File

@ -0,0 +1,607 @@
use std::marker::PhantomData;
use crate::field::extension_field::target::ExtensionTarget;
use crate::field::extension_field::Extendable;
use crate::field::field_types::{Field, PrimeField, RichField};
use crate::gates::gate::Gate;
use crate::iop::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator};
use crate::iop::target::Target;
use crate::iop::wire::Wire;
use crate::iop::witness::{PartitionWitness, Witness};
use crate::plonk::circuit_builder::CircuitBuilder;
use crate::plonk::plonk_common::{reduce_with_powers, reduce_with_powers_ext_recursive};
use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase};
use crate::util::{bits_u64, ceil_div_usize};
// TODO: replace/merge this gate with `ComparisonGate`.
/// A gate for checking that one value is less than or equal to another.
///
/// Wires 0 and 1 hold the two inputs and wire 2 holds the most significant chunk difference.
/// The remaining wires come in five groups of `num_chunks` wires each: the chunks of the
/// first input, the chunks of the second input, the equality dummies, the chunk-equality
/// flags and the intermediate values (see the `wire_*` accessors).
#[derive(Clone, Debug)]
pub struct AssertLessThanGate<F: PrimeField + Extendable<D>, const D: usize> {
/// Number of bits the inputs are split over; `new` debug-asserts that it is smaller than
/// the bit length of the field order.
pub(crate) num_bits: usize,
/// Number of chunks each input is split into; each chunk spans `chunk_bits()` bits.
pub(crate) num_chunks: usize,
/// Ties the gate to the field type `F` without storing a value of it.
_phantom: PhantomData<F>,
}
impl<F: RichField + Extendable<D>, const D: usize> AssertLessThanGate<F, D> {
    /// Index of the first per-chunk wire. Wires `0..3` hold the two inputs and the most
    /// significant chunk difference.
    const START_CHUNKS: usize = 3;

    /// Creates a gate comparing values of `num_bits` bits, split into `num_chunks` chunks.
    ///
    /// In debug builds this asserts that `num_bits` is smaller than the bit length of the
    /// field order.
    pub fn new(num_bits: usize, num_chunks: usize) -> Self {
        debug_assert!(num_bits < bits_u64(F::ORDER));
        Self {
            num_bits,
            num_chunks,
            _phantom: PhantomData,
        }
    }

    /// Number of bits per chunk: `num_bits / num_chunks`, rounded up.
    pub fn chunk_bits(&self) -> usize {
        let (total_bits, chunk_count) = (self.num_bits, self.num_chunks);
        ceil_div_usize(total_bits, chunk_count)
    }

    /// Wire holding the first input.
    pub fn wire_first_input(&self) -> usize {
        0
    }

    /// Wire holding the second input.
    pub fn wire_second_input(&self) -> usize {
        1
    }

    /// Wire holding the difference of the most significant pair of differing chunks.
    pub fn wire_most_significant_diff(&self) -> usize {
        2
    }

    /// Wire holding the `chunk`-th chunk of the first input.
    pub fn wire_first_chunk_val(&self, chunk: usize) -> usize {
        self.chunk_wire(0, chunk)
    }

    /// Wire holding the `chunk`-th chunk of the second input.
    pub fn wire_second_chunk_val(&self, chunk: usize) -> usize {
        self.chunk_wire(1, chunk)
    }

    /// Wire holding the helper value used to constrain the `chunk`-th equality flag.
    pub fn wire_equality_dummy(&self, chunk: usize) -> usize {
        self.chunk_wire(2, chunk)
    }

    /// Wire holding the flag telling whether the `chunk`-th chunks of both inputs are equal.
    pub fn wire_chunks_equal(&self, chunk: usize) -> usize {
        self.chunk_wire(3, chunk)
    }

    /// Wire holding the `chunk`-th intermediate value of the running difference.
    pub fn wire_intermediate_value(&self, chunk: usize) -> usize {
        self.chunk_wire(4, chunk)
    }

    /// Index of wire `chunk` inside the `group`-th block of `num_chunks` per-chunk wires.
    fn chunk_wire(&self, group: usize, chunk: usize) -> usize {
        debug_assert!(chunk < self.num_chunks);
        Self::START_CHUNKS + group * self.num_chunks + chunk
    }
}
/// Constraint evaluators and circuit metadata for `AssertLessThanGate`.
///
/// The three `eval_unfiltered*` methods push the same constraints in the same order:
/// two chunk-recombination checks, then five constraints per chunk (two range checks, two
/// equality-flag checks, one intermediate-value check), then the check of the most
/// significant difference wire and its range check.
impl<F: RichField + Extendable<D>, const D: usize> Gate<F, D> for AssertLessThanGate<F, D> {
/// Identifier built from the gate's `Debug` output and the extension degree `D`.
fn id(&self) -> String {
format!("{:?}<D={}>", self, D)
}
/// Evaluates all constraints over the extension field.
fn eval_unfiltered(&self, vars: EvaluationVars<F, D>) -> Vec<F::Extension> {
let mut constraints = Vec::with_capacity(self.num_constraints());
let first_input = vars.local_wires[self.wire_first_input()];
let second_input = vars.local_wires[self.wire_second_input()];
// Get chunks and assert that they match
let first_chunks: Vec<F::Extension> = (0..self.num_chunks)
.map(|i| vars.local_wires[self.wire_first_chunk_val(i)])
.collect();
let second_chunks: Vec<F::Extension> = (0..self.num_chunks)
.map(|i| vars.local_wires[self.wire_second_chunk_val(i)])
.collect();
// Recombine the chunks with base `2^chunk_bits`; the result must equal the input wire.
let first_chunks_combined = reduce_with_powers(
&first_chunks,
F::Extension::from_canonical_usize(1 << self.chunk_bits()),
);
let second_chunks_combined = reduce_with_powers(
&second_chunks,
F::Extension::from_canonical_usize(1 << self.chunk_bits()),
);
constraints.push(first_chunks_combined - first_input);
constraints.push(second_chunks_combined - second_input);
let chunk_size = 1 << self.chunk_bits();
let mut most_significant_diff_so_far = F::Extension::ZERO;
for i in 0..self.num_chunks {
// Range-check the chunks to be less than `chunk_size`.
let first_product = (0..chunk_size)
.map(|x| first_chunks[i] - F::Extension::from_canonical_usize(x))
.product();
let second_product = (0..chunk_size)
.map(|x| second_chunks[i] - F::Extension::from_canonical_usize(x))
.product();
constraints.push(first_product);
constraints.push(second_product);
let difference = second_chunks[i] - first_chunks[i];
let equality_dummy = vars.local_wires[self.wire_equality_dummy(i)];
let chunks_equal = vars.local_wires[self.wire_chunks_equal(i)];
// Two constraints to assert that `chunks_equal` is valid.
constraints.push(difference * equality_dummy - (F::Extension::ONE - chunks_equal));
constraints.push(chunks_equal * difference);
// Update `most_significant_diff_so_far`.
// If the chunks are equal the previous value is kept, otherwise it is replaced by
// `difference`.
let intermediate_value = vars.local_wires[self.wire_intermediate_value(i)];
constraints.push(intermediate_value - chunks_equal * most_significant_diff_so_far);
most_significant_diff_so_far =
intermediate_value + (F::Extension::ONE - chunks_equal) * difference;
}
let most_significant_diff = vars.local_wires[self.wire_most_significant_diff()];
constraints.push(most_significant_diff - most_significant_diff_so_far);
// Range check `most_significant_diff` to be less than `chunk_size`.
let product = (0..chunk_size)
.map(|x| most_significant_diff - F::Extension::from_canonical_usize(x))
.product();
constraints.push(product);
constraints
}
/// Evaluates all constraints over the base field; mirrors `eval_unfiltered` step by step.
fn eval_unfiltered_base(&self, vars: EvaluationVarsBase<F>) -> Vec<F> {
let mut constraints = Vec::with_capacity(self.num_constraints());
let first_input = vars.local_wires[self.wire_first_input()];
let second_input = vars.local_wires[self.wire_second_input()];
// Get chunks and assert that they match
let first_chunks: Vec<F> = (0..self.num_chunks)
.map(|i| vars.local_wires[self.wire_first_chunk_val(i)])
.collect();
let second_chunks: Vec<F> = (0..self.num_chunks)
.map(|i| vars.local_wires[self.wire_second_chunk_val(i)])
.collect();
// Recombine the chunks with base `2^chunk_bits`; the result must equal the input wire.
let first_chunks_combined = reduce_with_powers(
&first_chunks,
F::from_canonical_usize(1 << self.chunk_bits()),
);
let second_chunks_combined = reduce_with_powers(
&second_chunks,
F::from_canonical_usize(1 << self.chunk_bits()),
);
constraints.push(first_chunks_combined - first_input);
constraints.push(second_chunks_combined - second_input);
let chunk_size = 1 << self.chunk_bits();
let mut most_significant_diff_so_far = F::ZERO;
for i in 0..self.num_chunks {
// Range-check the chunks to be less than `chunk_size`.
let first_product = (0..chunk_size)
.map(|x| first_chunks[i] - F::from_canonical_usize(x))
.product();
let second_product = (0..chunk_size)
.map(|x| second_chunks[i] - F::from_canonical_usize(x))
.product();
constraints.push(first_product);
constraints.push(second_product);
let difference = second_chunks[i] - first_chunks[i];
let equality_dummy = vars.local_wires[self.wire_equality_dummy(i)];
let chunks_equal = vars.local_wires[self.wire_chunks_equal(i)];
// Two constraints to assert that `chunks_equal` is valid.
constraints.push(difference * equality_dummy - (F::ONE - chunks_equal));
constraints.push(chunks_equal * difference);
// Update `most_significant_diff_so_far`.
// If the chunks are equal the previous value is kept, otherwise it is replaced by
// `difference`.
let intermediate_value = vars.local_wires[self.wire_intermediate_value(i)];
constraints.push(intermediate_value - chunks_equal * most_significant_diff_so_far);
most_significant_diff_so_far =
intermediate_value + (F::ONE - chunks_equal) * difference;
}
let most_significant_diff = vars.local_wires[self.wire_most_significant_diff()];
constraints.push(most_significant_diff - most_significant_diff_so_far);
// Range check `most_significant_diff` to be less than `chunk_size`.
let product = (0..chunk_size)
.map(|x| most_significant_diff - F::from_canonical_usize(x))
.product();
constraints.push(product);
constraints
}
/// Builds the same constraints as `eval_unfiltered` out of circuit operations on `builder`.
fn eval_unfiltered_recursively(
&self,
builder: &mut CircuitBuilder<F, D>,
vars: EvaluationTargets<D>,
) -> Vec<ExtensionTarget<D>> {
let mut constraints = Vec::with_capacity(self.num_constraints());
let first_input = vars.local_wires[self.wire_first_input()];
let second_input = vars.local_wires[self.wire_second_input()];
// Get chunks and assert that they match
let first_chunks: Vec<ExtensionTarget<D>> = (0..self.num_chunks)
.map(|i| vars.local_wires[self.wire_first_chunk_val(i)])
.collect();
let second_chunks: Vec<ExtensionTarget<D>> = (0..self.num_chunks)
.map(|i| vars.local_wires[self.wire_second_chunk_val(i)])
.collect();
let chunk_base = builder.constant(F::from_canonical_usize(1 << self.chunk_bits()));
let first_chunks_combined =
reduce_with_powers_ext_recursive(builder, &first_chunks, chunk_base);
let second_chunks_combined =
reduce_with_powers_ext_recursive(builder, &second_chunks, chunk_base);
constraints.push(builder.sub_extension(first_chunks_combined, first_input));
constraints.push(builder.sub_extension(second_chunks_combined, second_input));
let chunk_size = 1 << self.chunk_bits();
let mut most_significant_diff_so_far = builder.zero_extension();
let one = builder.one_extension();
// Find the chosen chunk.
for i in 0..self.num_chunks {
// Range-check the chunks to be less than `chunk_size`.
let mut first_product = one;
let mut second_product = one;
for x in 0..chunk_size {
let x_f = builder.constant_extension(F::Extension::from_canonical_usize(x));
let first_diff = builder.sub_extension(first_chunks[i], x_f);
let second_diff = builder.sub_extension(second_chunks[i], x_f);
first_product = builder.mul_extension(first_product, first_diff);
second_product = builder.mul_extension(second_product, second_diff);
}
constraints.push(first_product);
constraints.push(second_product);
let difference = builder.sub_extension(second_chunks[i], first_chunks[i]);
let equality_dummy = vars.local_wires[self.wire_equality_dummy(i)];
let chunks_equal = vars.local_wires[self.wire_chunks_equal(i)];
// Two constraints to assert that `chunks_equal` is valid.
let diff_times_equal = builder.mul_extension(difference, equality_dummy);
let not_equal = builder.sub_extension(one, chunks_equal);
constraints.push(builder.sub_extension(diff_times_equal, not_equal));
constraints.push(builder.mul_extension(chunks_equal, difference));
// Update `most_significant_diff_so_far`.
let intermediate_value = vars.local_wires[self.wire_intermediate_value(i)];
let old_diff = builder.mul_extension(chunks_equal, most_significant_diff_so_far);
constraints.push(builder.sub_extension(intermediate_value, old_diff));
let not_equal = builder.sub_extension(one, chunks_equal);
let new_diff = builder.mul_extension(not_equal, difference);
most_significant_diff_so_far = builder.add_extension(intermediate_value, new_diff);
}
let most_significant_diff = vars.local_wires[self.wire_most_significant_diff()];
constraints
.push(builder.sub_extension(most_significant_diff, most_significant_diff_so_far));
// Range check `most_significant_diff` to be less than `chunk_size`.
let mut product = builder.one_extension();
for x in 0..chunk_size {
let x_f = builder.constant_extension(F::Extension::from_canonical_usize(x));
let diff = builder.sub_extension(most_significant_diff, x_f);
product = builder.mul_extension(product, diff);
}
constraints.push(product);
constraints
}
/// Returns the single witness generator for this gate instance, wrapped by `adapter()`.
fn generators(
&self,
gate_index: usize,
_local_constants: &[F],
) -> Vec<Box<dyn WitnessGenerator<F>>> {
let gen = AssertLessThanGenerator::<F, D> {
gate_index,
gate: self.clone(),
};
vec![Box::new(gen.adapter())]
}
/// Number of wires used: one past the last intermediate-value wire.
// Assumes `num_chunks >= 1`; `self.num_chunks - 1` underflows otherwise.
fn num_wires(&self) -> usize {
self.wire_intermediate_value(self.num_chunks - 1) + 1
}
/// This gate uses no constants.
fn num_constants(&self) -> usize {
0
}
/// Degree of the range-check products, which have `2^chunk_bits` linear factors each.
// NOTE(review): the `chunks_equal * most_significant_diff_so_far` term looks like degree 3,
// which would exceed this value when `chunk_bits() == 1` — confirm such parameters are unused.
fn degree(&self) -> usize {
1 << self.chunk_bits()
}
/// Two recombination checks, five constraints per chunk, and two for the final difference.
fn num_constraints(&self) -> usize {
4 + 5 * self.num_chunks
}
}
/// Witness generator for one `AssertLessThanGate` instance: reads the two input wires and
/// fills in every other wire of the gate.
#[derive(Debug)]
struct AssertLessThanGenerator<F: RichField + Extendable<D>, const D: usize> {
/// Index of the gate instance whose wires are read and written.
gate_index: usize,
/// Copy of the gate, used for its wire layout and chunk parameters.
gate: AssertLessThanGate<F, D>,
}
impl<F: RichField + Extendable<D>, const D: usize> SimpleGenerator<F>
    for AssertLessThanGenerator<F, D>
{
    /// The generator only needs the two input wires of its gate.
    fn dependencies(&self) -> Vec<Target> {
        let local_target = |input| Target::wire(self.gate_index, input);

        vec![
            local_target(self.gate.wire_first_input()),
            local_target(self.gate.wire_second_input()),
        ]
    }

    /// Fills in every non-input wire of the gate from the two input wires.
    ///
    /// Both inputs are split into `num_chunks` chunks of `chunk_bits()` bits, least
    /// significant chunk first. For each chunk position this writes the two chunk values,
    /// the equality flag, the equality dummy and the intermediate value; finally it writes
    /// the difference of the most significant pair of differing chunks (zero when the
    /// inputs are equal).
    ///
    /// In debug builds this panics if the first input is greater than the second.
    fn run_once(&self, witness: &PartitionWitness<F>, out_buffer: &mut GeneratedValues<F>) {
        let local_wire = |input| Wire {
            gate: self.gate_index,
            input,
        };
        let get_local_wire = |input| witness.get_wire(local_wire(input));

        let first_input_u64 = get_local_wire(self.gate.wire_first_input()).to_canonical_u64();
        let second_input_u64 = get_local_wire(self.gate.wire_second_input()).to_canonical_u64();

        // The gate asserts "less than or equal": equal inputs give a most significant
        // difference of zero, which satisfies every constraint, so they must be accepted.
        debug_assert!(first_input_u64 <= second_input_u64);

        let num_chunks = self.gate.num_chunks;
        let chunk_size: u64 = 1 << self.gate.chunk_bits();

        // Splits `value` into `num_chunks` base-`chunk_size` digits, least significant first.
        let split_into_chunks = |value: u64| -> Vec<F> {
            (0..num_chunks)
                .scan(value, |acc, _| {
                    let chunk = *acc % chunk_size;
                    *acc /= chunk_size;
                    Some(F::from_canonical_u64(chunk))
                })
                .collect()
        };
        let first_input_chunks = split_into_chunks(first_input_u64);
        let second_input_chunks = split_into_chunks(second_input_u64);

        // Walk from the least to the most significant chunk so that the running difference
        // ends up holding the difference of the most significant differing pair.
        let mut most_significant_diff = F::ZERO;
        for (i, (&first_chunk, &second_chunk)) in first_input_chunks
            .iter()
            .zip(&second_input_chunks)
            .enumerate()
        {
            let chunks_equal = first_chunk == second_chunk;
            let (equality_dummy, intermediate_value) = if chunks_equal {
                // Equal chunks carry the running difference forward unchanged.
                (F::ONE, most_significant_diff)
            } else {
                // The dummy is the inverse of the difference, as required by the gate's
                // `difference * equality_dummy = 1 - chunks_equal` constraint.
                let difference = second_chunk - first_chunk;
                most_significant_diff = difference;
                (F::ONE / difference, F::ZERO)
            };

            out_buffer.set_wire(local_wire(self.gate.wire_first_chunk_val(i)), first_chunk);
            out_buffer.set_wire(local_wire(self.gate.wire_second_chunk_val(i)), second_chunk);
            out_buffer.set_wire(local_wire(self.gate.wire_equality_dummy(i)), equality_dummy);
            out_buffer.set_wire(
                local_wire(self.gate.wire_chunks_equal(i)),
                F::from_bool(chunks_equal),
            );
            out_buffer.set_wire(
                local_wire(self.gate.wire_intermediate_value(i)),
                intermediate_value,
            );
        }

        out_buffer.set_wire(
            local_wire(self.gate.wire_most_significant_diff()),
            most_significant_diff,
        );
    }
}
#[cfg(test)]
mod tests {
    use std::marker::PhantomData;

    use anyhow::Result;
    use rand::Rng;

    use crate::field::extension_field::quartic::QuarticExtension;
    use crate::field::field_types::{Field, PrimeField};
    use crate::field::goldilocks_field::GoldilocksField;
    use crate::gates::assert_le::AssertLessThanGate;
    use crate::gates::gate::Gate;
    use crate::gates::gate_testing::{test_eval_fns, test_low_degree};
    use crate::hash::hash_types::HashOut;
    use crate::plonk::vars::EvaluationVars;

    /// Checks the wire layout: three fixed wires followed by five groups of `num_chunks`.
    #[test]
    fn wire_indices() {
        let gate = AssertLessThanGate::<GoldilocksField, 4> {
            num_bits: 40,
            num_chunks: 5,
            _phantom: PhantomData,
        };

        // Fixed wires.
        assert_eq!(gate.wire_first_input(), 0);
        assert_eq!(gate.wire_second_input(), 1);
        assert_eq!(gate.wire_most_significant_diff(), 2);

        // First and last wire of every per-chunk group.
        assert_eq!(gate.wire_first_chunk_val(0), 3);
        assert_eq!(gate.wire_first_chunk_val(4), 7);
        assert_eq!(gate.wire_second_chunk_val(0), 8);
        assert_eq!(gate.wire_second_chunk_val(4), 12);
        assert_eq!(gate.wire_equality_dummy(0), 13);
        assert_eq!(gate.wire_equality_dummy(4), 17);
        assert_eq!(gate.wire_chunks_equal(0), 18);
        assert_eq!(gate.wire_chunks_equal(4), 22);
        assert_eq!(gate.wire_intermediate_value(0), 23);
        assert_eq!(gate.wire_intermediate_value(4), 27);
    }

    #[test]
    fn low_degree() {
        // 20 bits split into 4 chunks.
        let gate = AssertLessThanGate::<_, 4>::new(20, 4);
        test_low_degree::<GoldilocksField, _, 4>(gate)
    }

    #[test]
    fn eval_fns() -> Result<()> {
        // 20 bits split into 4 chunks.
        let gate = AssertLessThanGate::<_, 4>::new(20, 4);
        test_eval_fns::<GoldilocksField, _, 4>(gate)
    }

    /// Builds a full wire assignment by hand and checks that every constraint vanishes, both
    /// for a random pair with `first <= second` and for two equal inputs.
    #[test]
    fn test_gate_constraint() {
        type F = GoldilocksField;
        type FF = QuarticExtension<GoldilocksField>;
        const D: usize = 4;

        let num_bits = 40;
        let num_chunks = 5;
        let chunk_bits = num_bits / num_chunks;

        // Returns the local wires for an AssertLessThanGate given the two inputs.
        let get_wires = |first_input: F, second_input: F| -> Vec<FF> {
            let chunk_size: u64 = 1 << chunk_bits;

            // Base-`chunk_size` digits of `input`, least significant first.
            let to_chunks = |input: F| -> Vec<F> {
                let mut remaining = input.to_canonical_u64();
                (0..num_chunks)
                    .map(|_| {
                        let chunk = remaining % chunk_size;
                        remaining /= chunk_size;
                        F::from_canonical_u64(chunk)
                    })
                    .collect()
            };
            let first_chunks = to_chunks(first_input);
            let second_chunks = to_chunks(second_input);

            let mut equality_dummies = Vec::with_capacity(num_chunks);
            let mut chunks_equal = Vec::with_capacity(num_chunks);
            let mut intermediate_values = Vec::with_capacity(num_chunks);
            let mut most_significant_diff = F::ZERO;
            for (&f, &s) in first_chunks.iter().zip(&second_chunks) {
                let equal = f == s;
                chunks_equal.push(F::from_bool(equal));
                if equal {
                    equality_dummies.push(F::ONE);
                    intermediate_values.push(most_significant_diff);
                } else {
                    most_significant_diff = s - f;
                    equality_dummies.push(F::ONE / most_significant_diff);
                    intermediate_values.push(F::ZERO);
                }
            }

            // Wire order must follow the gate's layout.
            vec![first_input, second_input, most_significant_diff]
                .into_iter()
                .chain(first_chunks)
                .chain(second_chunks)
                .chain(equality_dummies)
                .chain(chunks_equal)
                .chain(intermediate_values)
                .map(|x| x.into())
                .collect()
        };

        let mut rng = rand::thread_rng();
        let max: u64 = 1 << (num_bits - 1);
        let first_input_u64 = rng.gen_range(0..max);
        // Redraw until the second input is at least as large as the first.
        let second_input_u64 = loop {
            let candidate = rng.gen_range(0..max);
            if candidate >= first_input_u64 {
                break candidate;
            }
        };

        let first_input = F::from_canonical_u64(first_input_u64);
        let second_input = F::from_canonical_u64(second_input_u64);

        // Evaluates a fresh gate on the wires for `(first, second)` and requires all
        // constraints to be zero.
        let assert_satisfied = |first: F, second: F| {
            let gate = AssertLessThanGate::<F, D> {
                num_bits,
                num_chunks,
                _phantom: PhantomData,
            };
            let wires = get_wires(first, second);
            let vars = EvaluationVars {
                local_constants: &[],
                local_wires: &wires,
                public_inputs_hash: &HashOut::rand(),
            };
            assert!(
                gate.eval_unfiltered(vars).iter().all(|x| x.is_zero()),
                "Gate constraints are not satisfied."
            );
        };

        // Less-than-or-equal case.
        assert_satisfied(first_input, second_input);
        // Equal case.
        assert_satisfied(first_input, first_input);
    }
}

View File

@ -24,8 +24,7 @@ impl<const B: usize> BaseSumGate<B> {
}
pub fn new_from_config<F: PrimeField>(config: &CircuitConfig) -> Self {
let num_limbs = ((F::ORDER as f64).log(B as f64).floor() as usize)
.min(config.num_routed_wires - Self::START_LIMBS);
let num_limbs = F::BITS.min(config.num_routed_wires - Self::START_LIMBS);
Self::new(num_limbs)
}

View File

@ -43,33 +43,42 @@ impl<F: Extendable<D>, const D: usize> ComparisonGate<F, D> {
1
}
pub fn wire_most_significant_diff(&self) -> usize {
pub fn wire_result_bool(&self) -> usize {
2
}
pub fn wire_most_significant_diff(&self) -> usize {
3
}
pub fn wire_first_chunk_val(&self, chunk: usize) -> usize {
debug_assert!(chunk < self.num_chunks);
3 + chunk
4 + chunk
}
pub fn wire_second_chunk_val(&self, chunk: usize) -> usize {
debug_assert!(chunk < self.num_chunks);
3 + self.num_chunks + chunk
4 + self.num_chunks + chunk
}
pub fn wire_equality_dummy(&self, chunk: usize) -> usize {
debug_assert!(chunk < self.num_chunks);
3 + 2 * self.num_chunks + chunk
4 + 2 * self.num_chunks + chunk
}
pub fn wire_chunks_equal(&self, chunk: usize) -> usize {
debug_assert!(chunk < self.num_chunks);
3 + 3 * self.num_chunks + chunk
4 + 3 * self.num_chunks + chunk
}
pub fn wire_intermediate_value(&self, chunk: usize) -> usize {
debug_assert!(chunk < self.num_chunks);
3 + 4 * self.num_chunks + chunk
4 + 4 * self.num_chunks + chunk
}
/// The `bit_index`th bit of 2^n + most_significant_diff, where n is `chunk_bits()`.
pub fn wire_most_significant_diff_bit(&self, bit_index: usize) -> usize {
4 + 5 * self.num_chunks + bit_index
}
}
@ -110,10 +119,10 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for ComparisonGate<F, D> {
for i in 0..self.num_chunks {
// Range-check the chunks to be less than `chunk_size`.
let first_product = (0..chunk_size)
let first_product: F::Extension = (0..chunk_size)
.map(|x| first_chunks[i] - F::Extension::from_canonical_usize(x))
.product();
let second_product = (0..chunk_size)
let second_product: F::Extension = (0..chunk_size)
.map(|x| second_chunks[i] - F::Extension::from_canonical_usize(x))
.product();
constraints.push(first_product);
@ -137,11 +146,22 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for ComparisonGate<F, D> {
let most_significant_diff = vars.local_wires[self.wire_most_significant_diff()];
constraints.push(most_significant_diff - most_significant_diff_so_far);
// Range check `most_significant_diff` to be less than `chunk_size`.
let product = (0..chunk_size)
.map(|x| most_significant_diff - F::Extension::from_canonical_usize(x))
.product();
constraints.push(product);
let most_significant_diff_bits: Vec<F::Extension> = (0..self.chunk_bits() + 1)
.map(|i| vars.local_wires[self.wire_most_significant_diff_bit(i)])
.collect();
// Range-check the bits.
for &bit in &most_significant_diff_bits {
constraints.push(bit * (F::Extension::ONE - bit));
}
let bits_combined = reduce_with_powers(&most_significant_diff_bits, F::Extension::TWO);
let two_n = F::Extension::from_canonical_u64(1 << self.chunk_bits());
constraints.push((two_n + most_significant_diff) - bits_combined);
// Iff first <= second, the top (n + 1st) bit of (2^n + most_significant_diff) will be 1.
let result_bool = vars.local_wires[self.wire_result_bool()];
constraints.push(result_bool - most_significant_diff_bits[self.chunk_bits()]);
constraints
}
@ -178,10 +198,10 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for ComparisonGate<F, D> {
for i in 0..self.num_chunks {
// Range-check the chunks to be less than `chunk_size`.
let first_product = (0..chunk_size)
let first_product: F = (0..chunk_size)
.map(|x| first_chunks[i] - F::from_canonical_usize(x))
.product();
let second_product = (0..chunk_size)
let second_product: F = (0..chunk_size)
.map(|x| second_chunks[i] - F::from_canonical_usize(x))
.product();
constraints.push(first_product);
@ -205,11 +225,22 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for ComparisonGate<F, D> {
let most_significant_diff = vars.local_wires[self.wire_most_significant_diff()];
constraints.push(most_significant_diff - most_significant_diff_so_far);
// Range check `most_significant_diff` to be less than `chunk_size`.
let product = (0..chunk_size)
.map(|x| most_significant_diff - F::from_canonical_usize(x))
.product();
constraints.push(product);
let most_significant_diff_bits: Vec<F> = (0..self.chunk_bits() + 1)
.map(|i| vars.local_wires[self.wire_most_significant_diff_bit(i)])
.collect();
// Range-check the bits.
for &bit in &most_significant_diff_bits {
constraints.push(bit * (F::ONE - bit));
}
let bits_combined = reduce_with_powers(&most_significant_diff_bits, F::TWO);
let two_n = F::from_canonical_u64(1 << self.chunk_bits());
constraints.push((two_n + most_significant_diff) - bits_combined);
// Iff first <= second, the top (n + 1st) bit of (2^n + most_significant_diff) will be 1.
let result_bool = vars.local_wires[self.wire_result_bool()];
constraints.push(result_bool - most_significant_diff_bits[self.chunk_bits()]);
constraints
}
@ -285,14 +316,29 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for ComparisonGate<F, D> {
constraints
.push(builder.sub_extension(most_significant_diff, most_significant_diff_so_far));
// Range check `most_significant_diff` to be less than `chunk_size`.
let mut product = builder.one_extension();
for x in 0..chunk_size {
let x_f = builder.constant_extension(F::Extension::from_canonical_usize(x));
let diff = builder.sub_extension(most_significant_diff, x_f);
product = builder.mul_extension(product, diff);
let most_significant_diff_bits: Vec<ExtensionTarget<D>> = (0..self.chunk_bits() + 1)
.map(|i| vars.local_wires[self.wire_most_significant_diff_bit(i)])
.collect();
// Range-check the bits.
for &this_bit in &most_significant_diff_bits {
let inverse = builder.sub_extension(one, this_bit);
constraints.push(builder.mul_extension(this_bit, inverse));
}
constraints.push(product);
let two = builder.two();
let bits_combined =
reduce_with_powers_ext_recursive(builder, &most_significant_diff_bits, two);
let two_n =
builder.constant_extension(F::Extension::from_canonical_u64(1 << self.chunk_bits()));
let sum = builder.add_extension(two_n, most_significant_diff);
constraints.push(builder.sub_extension(sum, bits_combined));
// Iff first <= second, the top (n + 1st) bit of (2^n + most_significant_diff) will be 1.
let result_bool = vars.local_wires[self.wire_result_bool()];
constraints.push(
builder.sub_extension(result_bool, most_significant_diff_bits[self.chunk_bits()]),
);
constraints
}
@ -310,7 +356,7 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for ComparisonGate<F, D> {
}
fn num_wires(&self) -> usize {
self.wire_intermediate_value(self.num_chunks - 1) + 1
4 + 5 * self.num_chunks + (self.chunk_bits() + 1)
}
fn num_constants(&self) -> usize {
@ -322,7 +368,7 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for ComparisonGate<F, D> {
}
fn num_constraints(&self) -> usize {
4 + 5 * self.num_chunks
6 + 5 * self.num_chunks + self.chunk_bits()
}
}
@ -336,10 +382,10 @@ impl<F: Extendable<D>, const D: usize> SimpleGenerator<F> for ComparisonGenerato
fn dependencies(&self) -> Vec<Target> {
let local_target = |input| Target::wire(self.gate_index, input);
let mut deps = Vec::new();
deps.push(local_target(self.gate.wire_first_input()));
deps.push(local_target(self.gate.wire_second_input()));
deps
vec![
local_target(self.gate.wire_first_input()),
local_target(self.gate.wire_second_input()),
]
}
fn run_once(&self, witness: &PartitionWitness<F>, out_buffer: &mut GeneratedValues<F>) {
@ -356,7 +402,7 @@ impl<F: Extendable<D>, const D: usize> SimpleGenerator<F> for ComparisonGenerato
let first_input_u64 = first_input.to_canonical_u64();
let second_input_u64 = second_input.to_canonical_u64();
debug_assert!(first_input_u64 < second_input_u64);
let result = F::from_canonical_usize((first_input_u64 <= second_input_u64) as usize);
let chunk_size = 1 << self.gate.chunk_bits();
let first_input_chunks: Vec<F> = (0..self.gate.num_chunks)
@ -395,6 +441,22 @@ impl<F: Extendable<D>, const D: usize> SimpleGenerator<F> for ComparisonGenerato
}
let most_significant_diff = most_significant_diff_so_far;
let two_n = F::from_canonical_usize(1 << self.gate.chunk_bits());
let two_n_plus_msd = (two_n + most_significant_diff).to_canonical_u64();
let msd_bits_u64: Vec<u64> = (0..self.gate.chunk_bits() + 1)
.scan(two_n_plus_msd, |acc, _| {
let tmp = *acc % 2;
*acc /= 2;
Some(tmp)
})
.collect();
let msd_bits: Vec<F> = msd_bits_u64
.iter()
.map(|x| F::from_canonical_u64(*x))
.collect();
out_buffer.set_wire(local_wire(self.gate.wire_result_bool()), result);
out_buffer.set_wire(
local_wire(self.gate.wire_most_significant_diff()),
most_significant_diff,
@ -418,6 +480,12 @@ impl<F: Extendable<D>, const D: usize> SimpleGenerator<F> for ComparisonGenerato
intermediate_values[i],
);
}
for i in 0..self.gate.chunk_bits() + 1 {
out_buffer.set_wire(
local_wire(self.gate.wire_most_significant_diff_bit(i)),
msd_bits[i],
);
}
}
}
@ -451,17 +519,20 @@ mod tests {
assert_eq!(gate.wire_first_input(), 0);
assert_eq!(gate.wire_second_input(), 1);
assert_eq!(gate.wire_most_significant_diff(), 2);
assert_eq!(gate.wire_first_chunk_val(0), 3);
assert_eq!(gate.wire_first_chunk_val(4), 7);
assert_eq!(gate.wire_second_chunk_val(0), 8);
assert_eq!(gate.wire_second_chunk_val(4), 12);
assert_eq!(gate.wire_equality_dummy(0), 13);
assert_eq!(gate.wire_equality_dummy(4), 17);
assert_eq!(gate.wire_chunks_equal(0), 18);
assert_eq!(gate.wire_chunks_equal(4), 22);
assert_eq!(gate.wire_intermediate_value(0), 23);
assert_eq!(gate.wire_intermediate_value(4), 27);
assert_eq!(gate.wire_result_bool(), 2);
assert_eq!(gate.wire_most_significant_diff(), 3);
assert_eq!(gate.wire_first_chunk_val(0), 4);
assert_eq!(gate.wire_first_chunk_val(4), 8);
assert_eq!(gate.wire_second_chunk_val(0), 9);
assert_eq!(gate.wire_second_chunk_val(4), 13);
assert_eq!(gate.wire_equality_dummy(0), 14);
assert_eq!(gate.wire_equality_dummy(4), 18);
assert_eq!(gate.wire_chunks_equal(0), 19);
assert_eq!(gate.wire_chunks_equal(4), 23);
assert_eq!(gate.wire_intermediate_value(0), 24);
assert_eq!(gate.wire_intermediate_value(4), 28);
assert_eq!(gate.wire_most_significant_diff_bit(0), 29);
assert_eq!(gate.wire_most_significant_diff_bit(8), 37);
}
#[test]
@ -501,6 +572,8 @@ mod tests {
let first_input_u64 = first_input.to_canonical_u64();
let second_input_u64 = second_input.to_canonical_u64();
let result_bool = F::from_bool(first_input_u64 <= second_input_u64);
let chunk_size = 1 << chunk_bits;
let mut first_input_chunks: Vec<F> = (0..num_chunks)
.scan(first_input_u64, |acc, _| {
@ -538,20 +611,32 @@ mod tests {
}
let most_significant_diff = most_significant_diff_so_far;
let two_n_plus_msd =
(1 << chunk_bits) as u64 + most_significant_diff.to_canonical_u64();
let mut msd_bits: Vec<F> = (0..chunk_bits + 1)
.scan(two_n_plus_msd, |acc, _| {
let tmp = *acc % 2;
*acc /= 2;
Some(F::from_canonical_u64(tmp))
})
.collect();
v.push(first_input);
v.push(second_input);
v.push(result_bool);
v.push(most_significant_diff);
v.append(&mut first_input_chunks);
v.append(&mut second_input_chunks);
v.append(&mut equality_dummies);
v.append(&mut chunks_equal);
v.append(&mut intermediate_values);
v.append(&mut msd_bits);
v.iter().map(|&x| x.into()).collect::<Vec<_>>()
};
let mut rng = rand::thread_rng();
let max: u64 = 1 << num_bits - 1;
let max: u64 = 1 << (num_bits - 1);
let first_input_u64 = rng.gen_range(0..max);
let second_input_u64 = {
let mut val = rng.gen_range(0..max);

View File

@ -337,9 +337,8 @@ mod tests {
.map(|b| F::from_canonical_u64(*b))
.collect();
let mut v = Vec::new();
v.push(base);
v.extend(power_bits_f.clone());
let mut v = vec![base];
v.extend(power_bits_f);
let mut intermediate_values = Vec::new();
let mut current_intermediate_value = F::ONE;

View File

@ -10,7 +10,7 @@ use crate::plonk::circuit_data::CircuitConfig;
use crate::plonk::config::GenericConfig;
use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase};
use crate::plonk::verifier::verify;
use crate::polynomial::polynomial::{PolynomialCoeffs, PolynomialValues};
use crate::polynomial::{PolynomialCoeffs, PolynomialValues};
use crate::util::{log2_ceil, transpose};
const WITNESS_SIZE: usize = 1 << 5;

View File

@ -1,4 +1,4 @@
use log::info;
use log::debug;
use crate::field::extension_field::Extendable;
use crate::field::field_types::RichField;
@ -86,7 +86,7 @@ impl<F: RichField + Extendable<D>, const D: usize> Tree<GateRef<F, D>> {
}
}
}
info!(
debug!(
"Found tree with max degree {} and {} constants wires in {:.4}s.",
best_degree,
best_num_constants,
@ -221,12 +221,17 @@ impl<F: RichField + Extendable<D>, const D: usize> Tree<GateRef<F, D>> {
#[cfg(test)]
mod tests {
use log::info;
use super::*;
use crate::field::goldilocks_field::GoldilocksField;
use crate::gadgets::interpolation::InterpolationGate;
use crate::gates::arithmetic_extension::ArithmeticExtensionGate;
use crate::gates::arithmetic::ArithmeticExtensionGate;
use crate::gates::base_sum::BaseSumGate;
use crate::gates::constant::ConstantGate;
use crate::gates::gmimc::GMiMCGate;
use crate::gates::interpolation::InterpolationGate;
use crate::gates::interpolation::HighDegreeInterpolationGate;
use crate::gates::noop::NoopGate;
use crate::plonk::config::{GenericConfig, PoseidonGoldilocksConfig};
@ -243,7 +248,7 @@ mod tests {
GateRef::new(ArithmeticExtensionGate { num_ops: 4 }),
GateRef::new(BaseSumGate::<4>::new(4)),
GateRef::new(GMiMCGate::<F, D, 12>::new()),
GateRef::new(InterpolationGate::new(2)),
GateRef::new(HighDegreeInterpolationGate::new(2)),
];
let (tree, _, _) = Tree::from_gates(gates.clone());

View File

@ -318,8 +318,6 @@ impl<F: Extendable<D> + GMiMC<WIDTH>, const D: usize, const WIDTH: usize> Simple
#[cfg(test)]
mod tests {
use std::convert::TryInto;
use anyhow::Result;
use crate::field::field_types::Field;

View File

@ -1,4 +1,3 @@
use std::convert::TryInto;
use std::marker::PhantomData;
use std::ops::Range;
@ -252,8 +251,7 @@ impl<F: RichField + Extendable<D>, const D: usize> SimpleGenerator<F> for Insert
let local_targets = |inputs: Range<usize>| inputs.map(local_target);
let mut deps = Vec::new();
deps.push(local_target(self.gate.wires_insertion_index()));
let mut deps = vec![local_target(self.gate.wires_insertion_index())];
deps.extend(local_targets(self.gate.wires_element_to_insert()));
for i in 0..self.gate.vec_size {
deps.extend(local_targets(self.gate.wires_original_list_item(i)));
@ -292,7 +290,7 @@ impl<F: RichField + Extendable<D>, const D: usize> SimpleGenerator<F> for Insert
vec_size
);
let mut new_vec = orig_vec.clone();
let mut new_vec = orig_vec;
new_vec.insert(insertion_index, to_insert);
let mut equality_dummy_vals = Vec::new();
@ -377,14 +375,13 @@ mod tests {
fn get_wires(orig_vec: Vec<FF>, insertion_index: usize, element_to_insert: FF) -> Vec<FF> {
let vec_size = orig_vec.len();
let mut v = Vec::new();
v.push(F::from_canonical_usize(insertion_index));
let mut v = vec![F::from_canonical_usize(insertion_index)];
v.extend(element_to_insert.0);
for j in 0..vec_size {
v.extend(orig_vec[j].0);
}
let mut new_vec = orig_vec.clone();
let mut new_vec = orig_vec;
new_vec.insert(insertion_index, element_to_insert);
let mut equality_dummy_vals = Vec::new();
for i in 0..=vec_size {

View File

@ -1,4 +1,3 @@
use std::convert::TryInto;
use std::marker::PhantomData;
use std::ops::Range;
@ -6,6 +5,7 @@ use crate::field::extension_field::algebra::PolynomialCoeffsAlgebra;
use crate::field::extension_field::target::ExtensionTarget;
use crate::field::extension_field::{Extendable, FieldExtension};
use crate::field::interpolation::interpolant;
use crate::gadgets::interpolation::InterpolationGate;
use crate::gadgets::polynomial::PolynomialCoeffsExtAlgebraTarget;
use crate::gates::gate::Gate;
use crate::iop::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator};
@ -14,19 +14,20 @@ use crate::iop::wire::Wire;
use crate::iop::witness::{PartitionWitness, Witness};
use crate::plonk::circuit_builder::CircuitBuilder;
use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase};
use crate::polynomial::polynomial::PolynomialCoeffs;
use crate::polynomial::PolynomialCoeffs;
/// Interpolates a polynomial, whose points are a (base field) coset of the multiplicative subgroup
/// with the given size, and whose values are extension field elements, given by input wires.
/// Outputs the evaluation of the interpolant at a given (extension field) evaluation point.
#[derive(Clone, Debug)]
pub(crate) struct InterpolationGate<F: Extendable<D>, const D: usize> {
/// Interpolation gate with constraints of degree at most `1<<subgroup_bits`.
/// `eval_unfiltered_recursively` uses less gates than `LowDegreeInterpolationGate`.
#[derive(Copy, Clone, Debug)]
pub(crate) struct HighDegreeInterpolationGate<F: RichField + Extendable<D>, const D: usize> {
pub subgroup_bits: usize,
_phantom: PhantomData<F>,
}
impl<F: Extendable<D>, const D: usize> InterpolationGate<F, D> {
pub fn new(subgroup_bits: usize) -> Self {
impl<F: Extendable<D>, const D: usize> InterpolationGate<F, D>
for HighDegreeInterpolationGate<F, D>
{
fn new(subgroup_bits: usize) -> Self {
Self {
subgroup_bits,
_phantom: PhantomData,
@ -36,60 +37,9 @@ impl<F: Extendable<D>, const D: usize> InterpolationGate<F, D> {
fn num_points(&self) -> usize {
1 << self.subgroup_bits
}
}
/// Wire index of the coset shift.
pub fn wire_shift(&self) -> usize {
0
}
fn start_values(&self) -> usize {
1
}
/// Wire indices of the `i`th interpolant value.
pub fn wires_value(&self, i: usize) -> Range<usize> {
debug_assert!(i < self.num_points());
let start = self.start_values() + i * D;
start..start + D
}
fn start_evaluation_point(&self) -> usize {
self.start_values() + self.num_points() * D
}
/// Wire indices of the point to evaluate the interpolant at.
pub fn wires_evaluation_point(&self) -> Range<usize> {
let start = self.start_evaluation_point();
start..start + D
}
fn start_evaluation_value(&self) -> usize {
self.start_evaluation_point() + D
}
/// Wire indices of the interpolated value.
pub fn wires_evaluation_value(&self) -> Range<usize> {
let start = self.start_evaluation_value();
start..start + D
}
fn start_coeffs(&self) -> usize {
self.start_evaluation_value() + D
}
/// The number of routed wires required in the typical usage of this gate, where the points to
/// interpolate, the evaluation point, and the corresponding value are all routed.
pub(crate) fn num_routed_wires(&self) -> usize {
self.start_coeffs()
}
/// Wire indices of the interpolant's `i`th coefficient.
pub fn wires_coeff(&self, i: usize) -> Range<usize> {
debug_assert!(i < self.num_points());
let start = self.start_coeffs() + i * D;
start..start + D
}
impl<F: RichField + Extendable<D>, const D: usize> HighDegreeInterpolationGate<F, D> {
/// End of wire indices, exclusive.
fn end(&self) -> usize {
self.start_coeffs() + self.num_points() * D
@ -121,14 +71,16 @@ impl<F: Extendable<D>, const D: usize> InterpolationGate<F, D> {
g.powers()
.take(size)
.map(move |x| {
let subgroup_element = builder.constant(x.into());
let subgroup_element = builder.constant(x);
builder.scalar_mul_ext(subgroup_element, shift)
})
.collect()
}
}
impl<F: Extendable<D>, const D: usize> Gate<F, D> for InterpolationGate<F, D> {
impl<F: Extendable<D>, const D: usize> Gate<F, D>
for HighDegreeInterpolationGate<F, D>
{
fn id(&self) -> String {
format!("{:?}<D={}>", self, D)
}
@ -221,7 +173,7 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for InterpolationGate<F, D> {
) -> Vec<Box<dyn WitnessGenerator<F>>> {
let gen = InterpolationGenerator::<F, D> {
gate_index,
gate: self.clone(),
gate: *self,
_phantom: PhantomData,
};
vec![Box::new(gen.adapter())]
@ -251,7 +203,7 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for InterpolationGate<F, D> {
#[derive(Debug)]
struct InterpolationGenerator<F: Extendable<D>, const D: usize> {
gate_index: usize,
gate: InterpolationGate<F, D>,
gate: HighDegreeInterpolationGate<F, D>,
_phantom: PhantomData<F>,
}
@ -321,17 +273,18 @@ mod tests {
use crate::field::field_types::Field;
use crate::field::goldilocks_field::GoldilocksField;
use crate::gadgets::interpolation::InterpolationGate;
use crate::gates::gate::Gate;
use crate::gates::gate_testing::{test_eval_fns, test_low_degree};
use crate::gates::interpolation::InterpolationGate;
use crate::gates::interpolation::HighDegreeInterpolationGate;
use crate::hash::hash_types::HashOut;
use crate::plonk::config::{GenericConfig, PoseidonGoldilocksConfig};
use crate::plonk::vars::EvaluationVars;
use crate::polynomial::polynomial::PolynomialCoeffs;
use crate::polynomial::PolynomialCoeffs;
#[test]
fn wire_indices() {
let gate = InterpolationGate::<GoldilocksField, 4> {
let gate = HighDegreeInterpolationGate::<GoldilocksField, 4> {
subgroup_bits: 1,
_phantom: PhantomData,
};
@ -350,7 +303,7 @@ mod tests {
#[test]
fn low_degree() {
test_low_degree::<GoldilocksField, _, 4>(InterpolationGate::new(2));
test_low_degree::<GoldilocksField, _, 4>(HighDegreeInterpolationGate::new(2));
}
#[test]
@ -358,7 +311,7 @@ mod tests {
const D: usize = 2;
type C = PoseidonGoldilocksConfig;
type F = <C as GenericConfig<D>>::F;
test_eval_fns::<F, C, _, D>(InterpolationGate::new(2))
test_eval_fns::<F, C, _, D>(HighDegreeInterpolationGate::new(2))
}
#[test]
@ -370,7 +323,7 @@ mod tests {
/// Returns the local wires for an interpolation gate for given coeffs, points and eval point.
fn get_wires(
gate: &InterpolationGate<F, D>,
gate: &HighDegreeInterpolationGate<F, D>,
shift: F,
coeffs: PolynomialCoeffs<FF>,
eval_point: FF,
@ -392,7 +345,7 @@ mod tests {
let shift = F::rand();
let coeffs = PolynomialCoeffs::new(vec![FF::rand(), FF::rand()]);
let eval_point = FF::rand();
let gate = InterpolationGate::<F, D>::new(1);
let gate = HighDegreeInterpolationGate::<F, D>::new(1);
let vars = EvaluationVars {
local_constants: &[],
local_wires: &get_wires(&gate, shift, coeffs, eval_point),

View File

@ -0,0 +1,459 @@
use std::marker::PhantomData;
use std::ops::Range;
use crate::field::extension_field::algebra::PolynomialCoeffsAlgebra;
use crate::field::extension_field::target::ExtensionTarget;
use crate::field::extension_field::{Extendable, FieldExtension};
use crate::field::field_types::{Field, RichField};
use crate::field::interpolation::interpolant;
use crate::gadgets::interpolation::InterpolationGate;
use crate::gadgets::polynomial::PolynomialCoeffsExtAlgebraTarget;
use crate::gates::gate::Gate;
use crate::iop::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator};
use crate::iop::target::Target;
use crate::iop::wire::Wire;
use crate::iop::witness::{PartitionWitness, Witness};
use crate::plonk::circuit_builder::CircuitBuilder;
use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase};
use crate::polynomial::PolynomialCoeffs;
/// Interpolation gate with constraints of degree 2.
/// `eval_unfiltered_recursively` uses more gates than `HighDegreeInterpolationGate`.
///
/// The degree is kept at 2 by storing the successive powers of the coset shift and of the
/// evaluation point in extra wires and constraining consecutive powers against each other.
#[derive(Copy, Clone, Debug)]
pub(crate) struct LowDegreeInterpolationGate<F: RichField + Extendable<D>, const D: usize> {
    /// Log2 of the number of interpolated points: the gate handles `1 << subgroup_bits` points.
    pub subgroup_bits: usize,
    // Ties the gate to its base field `F` without storing a value of that type.
    _phantom: PhantomData<F>,
}
impl<F: RichField + Extendable<D>, const D: usize> InterpolationGate<F, D>
    for LowDegreeInterpolationGate<F, D>
{
    /// Creates a gate that interpolates over `1 << subgroup_bits` points.
    fn new(subgroup_bits: usize) -> Self {
        let _phantom = PhantomData;
        Self {
            subgroup_bits,
            _phantom,
        }
    }

    /// Number of points the interpolant is defined by, i.e. `2^subgroup_bits`.
    fn num_points(&self) -> usize {
        let bits = self.subgroup_bits;
        1 << bits
    }
}
impl<F: RichField + Extendable<D>, const D: usize> LowDegreeInterpolationGate<F, D> {
    /// `powers_shift(i)` is the wire index of `wire_shift^i`.
    ///
    /// The first power is the shift wire itself; powers `2..num_points` occupy one wire each,
    /// starting at `end_coeffs()`.
    pub fn powers_shift(&self, i: usize) -> usize {
        debug_assert!(0 < i && i < self.num_points());
        if i == 1 {
            return self.wire_shift();
        }
        self.end_coeffs() + i - 2
    }

    /// `powers_evaluation_point(i)` is the wire range of `evaluation_point^i`.
    ///
    /// The first power is the evaluation point itself; powers `2..num_points` occupy `D` wires
    /// each and are laid out right after the `num_points - 2` shift-power wires.
    pub fn powers_evaluation_point(&self, i: usize) -> Range<usize> {
        debug_assert!(0 < i && i < self.num_points());
        if i == 1 {
            return self.wires_evaluation_point();
        }
        let start = self.end_coeffs() + self.num_points() - 2 + (i - 2) * D;
        start..start + D
    }

    /// End of wire indices, exclusive.
    fn end(&self) -> usize {
        // After the coefficients come `num_points - 2` shift powers (one wire each) and
        // `num_points - 2` evaluation-point powers (`D` wires each). Computing the end from
        // `end_coeffs()` rather than from `powers_evaluation_point(num_points - 1)` keeps the
        // result correct when there are no extra power wires (`num_points == 2`), where the
        // last "power" is the evaluation point itself and sits before the coefficients.
        let extra_powers = self.num_points().saturating_sub(2);
        self.end_coeffs() + extra_powers * (D + 1)
    }

    /// The domain of the points we're interpolating.
    fn coset(&self, shift: F) -> impl Iterator<Item = F> {
        let g = F::primitive_root_of_unity(self.subgroup_bits);
        let size = 1 << self.subgroup_bits;
        // Speed matters here, so we avoid `cyclic_subgroup_coset_known_order` which allocates.
        g.powers().take(size).map(move |x| x * shift)
    }
}
impl<F: RichField + Extendable<D>, const D: usize> Gate<F, D> for LowDegreeInterpolationGate<F, D> {
    /// Identifier built from the gate's `Debug` output and the extension degree `D`.
    fn id(&self) -> String {
        format!("{:?}<D={}>", self, D)
    }

    /// Evaluates the gate constraints over the extension field.
    ///
    /// Constraints are emitted in this order: consecutive shift powers, coefficient/value
    /// consistency at every subgroup point, consecutive evaluation-point powers, and finally
    /// the claimed evaluation value. The other two evaluators follow the same order.
    fn eval_unfiltered(&self, vars: EvaluationVars<F, D>) -> Vec<F::Extension> {
        let mut constraints = Vec::with_capacity(self.num_constraints());
        let coeffs = (0..self.num_points())
            .map(|i| vars.get_local_ext_algebra(self.wires_coeff(i)))
            .collect::<Vec<_>>();
        // At this point `powers_shift[k]` is the wire claimed to hold `shift^(k + 1)`.
        let mut powers_shift = (1..self.num_points())
            .map(|i| vars.local_wires[self.powers_shift(i)])
            .collect::<Vec<_>>();
        let shift = powers_shift[0];
        // Each stored power must equal the previous one times `shift`.
        for i in 1..self.num_points() - 1 {
            constraints.push(powers_shift[i - 1] * shift - powers_shift[i]);
        }
        // Prepend `shift^0` so that `powers_shift[k]` lines up with the `k`th coefficient.
        powers_shift.insert(0, F::Extension::ONE);
        // `altered_coeffs[i] = c_i * shift^i`, where `c_i` is the original coefficient.
        // Then, `altered(w^i) = original(shift*w^i)`.
        let altered_coeffs = coeffs
            .iter()
            .zip(powers_shift)
            .map(|(&c, p)| c.scalar_mul(p))
            .collect::<Vec<_>>();
        let interpolant = PolynomialCoeffsAlgebra::new(coeffs);
        let altered_interpolant = PolynomialCoeffsAlgebra::new(altered_coeffs);
        // The shifted polynomial evaluated on the subgroup must reproduce the value wires.
        for (i, point) in F::Extension::two_adic_subgroup(self.subgroup_bits)
            .into_iter()
            .enumerate()
        {
            let value = vars.get_local_ext_algebra(self.wires_value(i));
            let computed_value = altered_interpolant.eval_base(point);
            constraints.extend(&(value - computed_value).to_basefield_array());
        }
        // `evaluation_point_powers[k]` is the wire range claimed to hold `evaluation_point^(k + 1)`.
        let evaluation_point_powers = (1..self.num_points())
            .map(|i| vars.get_local_ext_algebra(self.powers_evaluation_point(i)))
            .collect::<Vec<_>>();
        let evaluation_point = evaluation_point_powers[0];
        for i in 1..self.num_points() - 1 {
            constraints.extend(
                (evaluation_point_powers[i - 1] * evaluation_point - evaluation_point_powers[i])
                    .to_basefield_array(),
            );
        }
        // The claimed evaluation value must match the interpolant evaluated via the stored powers.
        let evaluation_value = vars.get_local_ext_algebra(self.wires_evaluation_value());
        let computed_evaluation_value = interpolant.eval_with_powers(&evaluation_point_powers);
        constraints.extend(&(evaluation_value - computed_evaluation_value).to_basefield_array());
        constraints
    }

    /// Base-field counterpart of `eval_unfiltered`; emits the same constraints in the same order.
    fn eval_unfiltered_base(&self, vars: EvaluationVarsBase<F>) -> Vec<F> {
        let mut constraints = Vec::with_capacity(self.num_constraints());
        let coeffs = (0..self.num_points())
            .map(|i| vars.get_local_ext(self.wires_coeff(i)))
            .collect::<Vec<_>>();
        // At this point `powers_shift[k]` is the wire claimed to hold `shift^(k + 1)`.
        let mut powers_shift = (1..self.num_points())
            .map(|i| vars.local_wires[self.powers_shift(i)])
            .collect::<Vec<_>>();
        let shift = powers_shift[0];
        // Each stored power must equal the previous one times `shift`.
        for i in 1..self.num_points() - 1 {
            constraints.push(powers_shift[i - 1] * shift - powers_shift[i]);
        }
        // Prepend `shift^0` so that `powers_shift[k]` lines up with the `k`th coefficient.
        powers_shift.insert(0, F::ONE);
        // `altered_coeffs[i] = c_i * shift^i`, where `c_i` is the original coefficient.
        // Then, `altered(w^i) = original(shift*w^i)`.
        let altered_coeffs = coeffs
            .iter()
            .zip(powers_shift)
            .map(|(&c, p)| c.scalar_mul(p))
            .collect::<Vec<_>>();
        let interpolant = PolynomialCoeffs::new(coeffs);
        let altered_interpolant = PolynomialCoeffs::new(altered_coeffs);
        // The shifted polynomial evaluated on the subgroup must reproduce the value wires.
        for (i, point) in F::two_adic_subgroup(self.subgroup_bits)
            .into_iter()
            .enumerate()
        {
            let value = vars.get_local_ext(self.wires_value(i));
            let computed_value = altered_interpolant.eval_base(point);
            constraints.extend(&(value - computed_value).to_basefield_array());
        }
        // `evaluation_point_powers[k]` is the wire range claimed to hold `evaluation_point^(k + 1)`.
        let evaluation_point_powers = (1..self.num_points())
            .map(|i| vars.get_local_ext(self.powers_evaluation_point(i)))
            .collect::<Vec<_>>();
        let evaluation_point = evaluation_point_powers[0];
        for i in 1..self.num_points() - 1 {
            constraints.extend(
                (evaluation_point_powers[i - 1] * evaluation_point - evaluation_point_powers[i])
                    .to_basefield_array(),
            );
        }
        // The claimed evaluation value must match the interpolant evaluated via the stored powers.
        let evaluation_value = vars.get_local_ext(self.wires_evaluation_value());
        let computed_evaluation_value = interpolant.eval_with_powers(&evaluation_point_powers);
        constraints.extend(&(evaluation_value - computed_evaluation_value).to_basefield_array());
        constraints
    }

    /// In-circuit counterpart of `eval_unfiltered`: builds the same constraints, in the same
    /// order, out of `builder` operations on extension targets.
    fn eval_unfiltered_recursively(
        &self,
        builder: &mut CircuitBuilder<F, D>,
        vars: EvaluationTargets<D>,
    ) -> Vec<ExtensionTarget<D>> {
        let mut constraints = Vec::with_capacity(self.num_constraints());
        let coeffs = (0..self.num_points())
            .map(|i| vars.get_local_ext_algebra(self.wires_coeff(i)))
            .collect::<Vec<_>>();
        // At this point `powers_shift[k]` is the target claimed to hold `shift^(k + 1)`.
        let mut powers_shift = (1..self.num_points())
            .map(|i| vars.local_wires[self.powers_shift(i)])
            .collect::<Vec<_>>();
        let shift = powers_shift[0];
        // `powers_shift[i - 1] * shift - powers_shift[i]`, as in the non-recursive evaluators.
        for i in 1..self.num_points() - 1 {
            constraints.push(builder.mul_sub_extension(
                powers_shift[i - 1],
                shift,
                powers_shift[i],
            ));
        }
        // Prepend `shift^0` so that `powers_shift[k]` lines up with the `k`th coefficient.
        powers_shift.insert(0, builder.one_extension());
        // `altered_coeffs[i] = c_i * shift^i`, where `c_i` is the original coefficient.
        // Then, `altered(w^i) = original(shift*w^i)`.
        let altered_coeffs = coeffs
            .iter()
            .zip(powers_shift)
            .map(|(&c, p)| builder.scalar_mul_ext_algebra(p, c))
            .collect::<Vec<_>>();
        let interpolant = PolynomialCoeffsExtAlgebraTarget(coeffs);
        let altered_interpolant = PolynomialCoeffsExtAlgebraTarget(altered_coeffs);
        // The shifted polynomial evaluated on the subgroup must reproduce the value wires.
        for (i, point) in F::Extension::two_adic_subgroup(self.subgroup_bits)
            .into_iter()
            .enumerate()
        {
            let value = vars.get_local_ext_algebra(self.wires_value(i));
            let point = builder.constant_extension(point);
            let computed_value = altered_interpolant.eval_scalar(builder, point);
            constraints.extend(
                &builder
                    .sub_ext_algebra(value, computed_value)
                    .to_ext_target_array(),
            );
        }
        // `evaluation_point_powers[k]` is the target claimed to hold `evaluation_point^(k + 1)`.
        let evaluation_point_powers = (1..self.num_points())
            .map(|i| vars.get_local_ext_algebra(self.powers_evaluation_point(i)))
            .collect::<Vec<_>>();
        let evaluation_point = evaluation_point_powers[0];
        for i in 1..self.num_points() - 1 {
            // Built as `evaluation_point * previous_power + (-1) * next_power`.
            let neg_one_ext = builder.neg_one_extension();
            let neg_new_power =
                builder.scalar_mul_ext_algebra(neg_one_ext, evaluation_point_powers[i]);
            let constraint = builder.mul_add_ext_algebra(
                evaluation_point,
                evaluation_point_powers[i - 1],
                neg_new_power,
            );
            constraints.extend(constraint.to_ext_target_array());
        }
        // The claimed evaluation value must match the interpolant evaluated via the stored powers.
        let evaluation_value = vars.get_local_ext_algebra(self.wires_evaluation_value());
        let computed_evaluation_value =
            interpolant.eval_with_powers(builder, &evaluation_point_powers);
        constraints.extend(
            &builder
                .sub_ext_algebra(evaluation_value, computed_evaluation_value)
                .to_ext_target_array(),
        );
        constraints
    }

    /// Returns the single witness generator that fills this gate's derived wires.
    fn generators(
        &self,
        gate_index: usize,
        _local_constants: &[F],
    ) -> Vec<Box<dyn WitnessGenerator<F>>> {
        let gen = InterpolationGenerator::<F, D> {
            gate_index,
            gate: *self,
            _phantom: PhantomData,
        };
        vec![Box::new(gen.adapter())]
    }

    /// Total number of wires used by the gate.
    fn num_wires(&self) -> usize {
        self.end()
    }

    /// The gate reads no constants.
    fn num_constants(&self) -> usize {
        0
    }

    /// Every constraint multiplies at most two wire values together.
    fn degree(&self) -> usize {
        2
    }

    fn num_constraints(&self) -> usize {
        // `num_points * D` constraints to check for consistency between the coefficients and the
        // point-value pairs, plus `D` constraints for the evaluation value, plus `(D+1)*(num_points-2)`
        // to check power constraints for evaluation point and shift.
        self.num_points() * D + D + (D + 1) * (self.num_points() - 2)
    }
}
/// Witness generator for a single `LowDegreeInterpolationGate`: from the shift, the evaluation
/// point and the point values it fills in the power wires, the interpolant's coefficients and
/// the evaluation value.
#[derive(Debug)]
struct InterpolationGenerator<F: RichField + Extendable<D>, const D: usize> {
    /// Index of the gate whose wires are read and written.
    gate_index: usize,
    /// Copy of the gate, used to look up wire positions.
    gate: LowDegreeInterpolationGate<F, D>,
    _phantom: PhantomData<F>,
}
impl<F: RichField + Extendable<D>, const D: usize> SimpleGenerator<F>
    for InterpolationGenerator<F, D>
{
    /// Wires that must be known before the generator can run: the coset shift, the evaluation
    /// point and every interpolated value, in that order.
    fn dependencies(&self) -> Vec<Target> {
        let gate_index = self.gate_index;
        let to_target = move |input: usize| {
            Target::Wire(Wire {
                gate: gate_index,
                input,
            })
        };

        let value_wires = (0..self.gate.num_points()).flat_map(|i| self.gate.wires_value(i));
        std::iter::once(self.gate.wire_shift())
            .chain(self.gate.wires_evaluation_point())
            .chain(value_wires)
            .map(to_target)
            .collect()
    }

    /// Fills in the shift powers, the interpolant's coefficients, the evaluation-point powers
    /// and the evaluation value.
    fn run_once(&self, witness: &PartitionWitness<F>, out_buffer: &mut GeneratedValues<F>) {
        let gate = &self.gate;
        let num_points = gate.num_points();

        let wire_at = |input| Wire {
            gate: self.gate_index,
            input,
        };
        let read_wire = |input| witness.get_wire(wire_at(input));
        // Reads `D` consecutive wires as one extension field element.
        let read_ext = |range: Range<usize>| {
            debug_assert_eq!(range.len(), D);
            let limbs = range.map(read_wire).collect::<Vec<_>>();
            let limbs = limbs.try_into().unwrap();
            F::Extension::from_basefield_array(limbs)
        };

        // Powers 0 and 1 need no dedicated wire, hence the `skip(2)`.
        let shift = read_wire(gate.wire_shift());
        for (i, power) in shift.powers().take(num_points).enumerate().skip(2) {
            out_buffer.set_wire(wire_at(gate.powers_shift(i)), power);
        }

        // Compute the interpolant through the (coset point, value) pairs.
        let points = gate
            .coset(shift)
            .enumerate()
            .map(|(i, point)| (point.into(), read_ext(gate.wires_value(i))))
            .collect::<Vec<_>>();
        let poly = interpolant(&points);
        for (i, &coeff) in poly.coeffs.iter().enumerate() {
            out_buffer.set_ext_wires(gate.wires_coeff(i).map(wire_at), coeff);
        }

        let eval_point = read_ext(gate.wires_evaluation_point());
        for (i, power) in eval_point.powers().take(num_points).enumerate().skip(2) {
            let target =
                ExtensionTarget::from_range(self.gate_index, gate.powers_evaluation_point(i));
            out_buffer.set_extension_target(target, power);
        }

        let eval_value = poly.eval(eval_point);
        let eval_value_wires = gate.wires_evaluation_value().map(wire_at);
        out_buffer.set_ext_wires(eval_value_wires, eval_value);
    }
}
#[cfg(test)]
mod tests {
    use anyhow::Result;

    use crate::field::extension_field::quadratic::QuadraticExtension;
    use crate::field::field_types::Field;
    use crate::field::goldilocks_field::GoldilocksField;
    use crate::gadgets::interpolation::InterpolationGate;
    use crate::gates::gate::Gate;
    use crate::gates::gate_testing::{test_eval_fns, test_low_degree};
    use crate::gates::low_degree_interpolation::LowDegreeInterpolationGate;
    use crate::hash::hash_types::HashOut;
    use crate::plonk::config::{GenericConfig, PoseidonGoldilocksConfig};
    use crate::plonk::vars::EvaluationVars;
    use crate::polynomial::PolynomialCoeffs;

    /// The gate's constraints must not exceed its declared degree.
    #[test]
    fn low_degree() {
        test_low_degree::<GoldilocksField, _, 4>(LowDegreeInterpolationGate::new(4));
    }

    /// The evaluators must agree with each other. Called with an explicit `GenericConfig`,
    /// the same way the `HighDegreeInterpolationGate` test calls `test_eval_fns`.
    #[test]
    fn eval_fns() -> Result<()> {
        const D: usize = 2;
        type C = PoseidonGoldilocksConfig;
        type F = <C as GenericConfig<D>>::F;
        test_eval_fns::<F, C, _, D>(LowDegreeInterpolationGate::new(4))
    }

    /// A correctly filled row must satisfy every constraint of the gate.
    #[test]
    fn test_gate_constraint() {
        type F = GoldilocksField;
        type FF = QuadraticExtension<GoldilocksField>;
        const D: usize = 2;

        /// Returns the local wires for an interpolation gate for given coeffs, points and eval point.
        fn get_wires(
            gate: &LowDegreeInterpolationGate<F, D>,
            shift: F,
            coeffs: PolynomialCoeffs<FF>,
            eval_point: FF,
        ) -> Vec<FF> {
            let points = gate.coset(shift);
            // Wire order: shift, values, evaluation point, evaluation value, coefficients,
            // shift powers, evaluation-point powers.
            let mut v = vec![shift];
            for x in points {
                v.extend(coeffs.eval(x.into()).0);
            }
            v.extend(eval_point.0);
            v.extend(coeffs.eval(eval_point).0);
            for i in 0..coeffs.len() {
                v.extend(coeffs.coeffs[i].0);
            }
            // Powers 0 and 1 have no dedicated wires.
            v.extend(shift.powers().skip(2).take(gate.num_points() - 2));
            v.extend(
                eval_point
                    .powers()
                    .skip(2)
                    .take(gate.num_points() - 2)
                    .flat_map(|ff| ff.0),
            );
            v.iter().map(|&x| x.into()).collect::<Vec<_>>()
        }

        // Get a working row for LowDegreeInterpolationGate.
        let subgroup_bits = 4;
        let shift = F::rand();
        let coeffs = PolynomialCoeffs::new(FF::rand_vec(1 << subgroup_bits));
        let eval_point = FF::rand();
        let gate = LowDegreeInterpolationGate::<F, D>::new(subgroup_bits);
        let vars = EvaluationVars {
            local_constants: &[],
            local_wires: &get_wires(&gate, shift, coeffs, eval_point),
            public_inputs_hash: &HashOut::rand(),
        };

        assert!(
            gate.eval_unfiltered(vars).iter().all(|x| x.is_zero()),
            "Gate constraints are not satisfied."
        );
    }
}

View File

@ -1,8 +1,10 @@
// Gates have `new` methods that return `GateRef`s.
#![allow(clippy::new_ret_no_self)]
pub mod arithmetic;
pub mod arithmetic_base;
pub mod arithmetic_extension;
pub mod arithmetic_u32;
pub mod assert_le;
pub mod base_sum;
pub mod comparison;
pub mod constant;
@ -12,12 +14,16 @@ pub mod gate_tree;
pub mod gmimc;
pub mod insertion;
pub mod interpolation;
pub mod low_degree_interpolation;
pub mod multiplication_extension;
pub mod noop;
pub mod poseidon;
pub(crate) mod poseidon_mds;
pub(crate) mod public_input;
pub mod random_access;
pub mod reducing;
pub mod reducing_extension;
pub mod subtraction_u32;
pub mod switch;
#[cfg(test)]

View File

@ -0,0 +1,204 @@
use std::ops::Range;
use crate::field::extension_field::target::ExtensionTarget;
use crate::field::extension_field::Extendable;
use crate::field::extension_field::FieldExtension;
use crate::field::field_types::RichField;
use crate::gates::gate::Gate;
use crate::iop::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator};
use crate::iop::target::Target;
use crate::iop::witness::{PartitionWitness, Witness};
use crate::plonk::circuit_builder::CircuitBuilder;
use crate::plonk::circuit_data::CircuitConfig;
use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase};
/// A gate which can perform a weighted multiplication, i.e. `result = c0 * x * y`, where `c0`
/// is the gate's single constant and `x`, `y` and `result` are extension field elements. If
/// the config supports enough routed wires, it can support several such operations in one gate.
#[derive(Debug)]
pub struct MulExtensionGate<const D: usize> {
    /// Number of multiplications performed by the gate.
    pub num_ops: usize,
}
impl<const D: usize> MulExtensionGate<D> {
    /// Builds a gate holding as many multiplications as `config` has routed wires for.
    pub fn new_from_config(config: &CircuitConfig) -> Self {
        let num_ops = Self::num_ops(config);
        Self { num_ops }
    }

    /// Determine the maximum number of operations that can fit in one gate for the given config.
    pub(crate) fn num_ops(config: &CircuitConfig) -> usize {
        // Every operation uses two multiplicands and one output, `D` wires apiece.
        config.num_routed_wires / (3 * D)
    }

    /// Wires of the first multiplicand of the `i`th operation.
    pub fn wires_ith_multiplicand_0(i: usize) -> Range<usize> {
        let start = 3 * D * i;
        start..start + D
    }

    /// Wires of the second multiplicand of the `i`th operation.
    pub fn wires_ith_multiplicand_1(i: usize) -> Range<usize> {
        let start = 3 * D * i + D;
        start..start + D
    }

    /// Wires of the output of the `i`th operation.
    pub fn wires_ith_output(i: usize) -> Range<usize> {
        let start = 3 * D * i + 2 * D;
        start..start + D
    }
}
impl<F: RichField + Extendable<D>, const D: usize> Gate<F, D> for MulExtensionGate<D> {
    /// Identifier derived from the gate's `Debug` output.
    fn id(&self) -> String {
        format!("{:?}", self)
    }

    /// For every operation, constrains `output - c0 * x * y` to be zero.
    fn eval_unfiltered(&self, vars: EvaluationVars<F, D>) -> Vec<F::Extension> {
        let weight = vars.local_constants[0];

        (0..self.num_ops)
            .flat_map(|op| {
                let lhs = vars.get_local_ext_algebra(Self::wires_ith_multiplicand_0(op));
                let rhs = vars.get_local_ext_algebra(Self::wires_ith_multiplicand_1(op));
                let claimed = vars.get_local_ext_algebra(Self::wires_ith_output(op));
                let expected = (lhs * rhs).scalar_mul(weight);
                (claimed - expected).to_basefield_array()
            })
            .collect()
    }

    /// Base-field counterpart of `eval_unfiltered`.
    fn eval_unfiltered_base(&self, vars: EvaluationVarsBase<F>) -> Vec<F> {
        let weight = vars.local_constants[0];

        (0..self.num_ops)
            .flat_map(|op| {
                let lhs = vars.get_local_ext(Self::wires_ith_multiplicand_0(op));
                let rhs = vars.get_local_ext(Self::wires_ith_multiplicand_1(op));
                let claimed = vars.get_local_ext(Self::wires_ith_output(op));
                let expected = (lhs * rhs).scalar_mul(weight);
                (claimed - expected).to_basefield_array()
            })
            .collect()
    }

    /// In-circuit counterpart of `eval_unfiltered`, built from `builder` operations.
    fn eval_unfiltered_recursively(
        &self,
        builder: &mut CircuitBuilder<F, D>,
        vars: EvaluationTargets<D>,
    ) -> Vec<ExtensionTarget<D>> {
        let weight = vars.local_constants[0];
        let mut constraints = Vec::new();

        for op in 0..self.num_ops {
            let lhs = vars.get_local_ext_algebra(Self::wires_ith_multiplicand_0(op));
            let rhs = vars.get_local_ext_algebra(Self::wires_ith_multiplicand_1(op));
            let claimed = vars.get_local_ext_algebra(Self::wires_ith_output(op));

            let product = builder.mul_ext_algebra(lhs, rhs);
            let expected = builder.scalar_mul_ext_algebra(weight, product);
            let diff = builder.sub_ext_algebra(claimed, expected);
            constraints.extend(diff.to_ext_target_array());
        }

        constraints
    }

    /// One generator per operation, each computing that operation's output wires.
    fn generators(
        &self,
        gate_index: usize,
        local_constants: &[F],
    ) -> Vec<Box<dyn WitnessGenerator<F>>> {
        let mut generators: Vec<Box<dyn WitnessGenerator<F>>> = Vec::new();
        for i in 0..self.num_ops {
            let generator = MulExtensionGenerator::<F, D> {
                gate_index,
                const_0: local_constants[0],
                i,
            };
            generators.push(Box::new(generator.adapter()));
        }
        generators
    }

    /// Three `D`-wire groups (two multiplicands and one output) per operation.
    fn num_wires(&self) -> usize {
        self.num_ops * 3 * D
    }

    /// The multiplication weight is the only constant.
    fn num_constants(&self) -> usize {
        1
    }

    /// Each constraint multiplies the constant with two wire values.
    fn degree(&self) -> usize {
        3
    }

    /// `D` base constraints per operation.
    fn num_constraints(&self) -> usize {
        self.num_ops * D
    }
}
/// Witness generator for a single multiplication of a `MulExtensionGate`: it reads the two
/// multiplicands of operation `i` and writes `const_0 * x * y` to that operation's output wires.
#[derive(Clone, Debug)]
struct MulExtensionGenerator<F: RichField + Extendable<D>, const D: usize> {
/// Index of the gate instance whose wires are read and written.
gate_index: usize,
/// The weight `c0`, taken from the gate's first local constant.
const_0: F,
/// Index of the operation within the gate handled by this generator.
i: usize,
}
impl<F: RichField + Extendable<D>, const D: usize> SimpleGenerator<F>
for MulExtensionGenerator<F, D>
{
fn dependencies(&self) -> Vec<Target> {
MulExtensionGate::<D>::wires_ith_multiplicand_0(self.i)
.chain(MulExtensionGate::<D>::wires_ith_multiplicand_1(self.i))
.map(|i| Target::wire(self.gate_index, i))
.collect()
}
fn run_once(&self, witness: &PartitionWitness<F>, out_buffer: &mut GeneratedValues<F>) {
let extract_extension = |range: Range<usize>| -> F::Extension {
let t = ExtensionTarget::from_range(self.gate_index, range);
witness.get_extension_target(t)
};
let multiplicand_0 =
extract_extension(MulExtensionGate::<D>::wires_ith_multiplicand_0(self.i));
let multiplicand_1 =
extract_extension(MulExtensionGate::<D>::wires_ith_multiplicand_1(self.i));
let output_target = ExtensionTarget::from_range(
self.gate_index,
MulExtensionGate::<D>::wires_ith_output(self.i),
);
let computed_output = (multiplicand_0 * multiplicand_1).scalar_mul(self.const_0);
out_buffer.set_extension_target(output_target, computed_output)
}
}
#[cfg(test)]
mod tests {
    use anyhow::Result;

    use crate::field::goldilocks_field::GoldilocksField;
    use crate::gates::gate_testing::{test_eval_fns, test_low_degree};
    use crate::gates::multiplication_extension::MulExtensionGate;
    use crate::plonk::circuit_data::CircuitConfig;

    type F = GoldilocksField;
    const D: usize = 4;

    /// Builds the gate under test from the standard recursion config.
    fn standard_gate() -> MulExtensionGate<D> {
        MulExtensionGate::new_from_config(&CircuitConfig::standard_recursion_config())
    }

    /// Runs the shared `test_low_degree` check on the gate.
    #[test]
    fn low_degree() {
        test_low_degree::<F, _, D>(standard_gate());
    }

    /// Runs the shared `test_eval_fns` check on the gate.
    #[test]
    fn eval_fns() -> Result<()> {
        test_eval_fns::<F, _, D>(standard_gate())
    }
}

View File

@ -1,4 +1,3 @@
use std::convert::TryInto;
use std::marker::PhantomData;
use crate::field::extension_field::target::ExtensionTarget;
@ -47,44 +46,59 @@ impl<F: Extendable<D>, const D: usize> PoseidonGate<F, D> {
/// is useful for ordering hashes in Merkle proofs. Otherwise, this should be set to 0.
pub const WIRE_SWAP: usize = 2 * SPONGE_WIDTH;
const START_DELTA: usize = 2 * WIDTH + 1;
/// A wire which stores `swap * (input[i + 4] - input[i])`; used to compute the swapped inputs.
///
/// Panics if `i >= 4`: there is one delta wire per pair `(input[i], input[i + 4])`.
fn wire_delta(i: usize) -> usize {
assert!(i < 4);
// The four delta wires sit at `START_DELTA..START_DELTA + 4`.
Self::START_DELTA + i
}
const START_FULL_0: usize = Self::START_DELTA + 4;
/// A wire which stores the input of the `i`-th S-box of the `round`-th round of the first set
/// of full rounds.
fn wire_full_sbox_0(round: usize, i: usize) -> usize {
debug_assert!(
round != 0,
"First round S-box inputs are not stored as wires"
);
debug_assert!(round < poseidon::HALF_N_FULL_ROUNDS);
debug_assert!(i < SPONGE_WIDTH);
2 * SPONGE_WIDTH + 1 + SPONGE_WIDTH * round + i
debug_assert!(i < WIDTH);
Self::START_FULL_0 + WIDTH * (round - 1) + i
}
const START_PARTIAL: usize = Self::START_FULL_0 + WIDTH * (poseidon::HALF_N_FULL_ROUNDS - 1);
/// A wire which stores the input of the S-box of the `round`-th round of the partial rounds.
fn wire_partial_sbox(round: usize) -> usize {
debug_assert!(round < poseidon::N_PARTIAL_ROUNDS);
2 * SPONGE_WIDTH + 1 + SPONGE_WIDTH * poseidon::HALF_N_FULL_ROUNDS + round
Self::START_PARTIAL + round
}
const START_FULL_1: usize = Self::START_PARTIAL + poseidon::N_PARTIAL_ROUNDS;
/// A wire which stores the input of the `i`-th S-box of the `round`-th round of the second set
/// of full rounds.
fn wire_full_sbox_1(round: usize, i: usize) -> usize {
debug_assert!(round < poseidon::HALF_N_FULL_ROUNDS);
debug_assert!(i < SPONGE_WIDTH);
2 * SPONGE_WIDTH
+ 1
+ SPONGE_WIDTH * (poseidon::HALF_N_FULL_ROUNDS + round)
+ poseidon::N_PARTIAL_ROUNDS
+ i
debug_assert!(i < WIDTH);
Self::START_FULL_1 + WIDTH * round + i
}
/// End of wire indices, exclusive.
fn end() -> usize {
2 * SPONGE_WIDTH
+ 1
+ SPONGE_WIDTH * poseidon::N_FULL_ROUNDS_TOTAL
+ poseidon::N_PARTIAL_ROUNDS
Self::START_FULL_1 + WIDTH * poseidon::HALF_N_FULL_ROUNDS
}
}
impl<F: Extendable<D>, const D: usize> Gate<F, D> for PoseidonGate<F, D> {
impl<F: RichField + Extendable<D> + Poseidon<WIDTH>, const D: usize, const WIDTH: usize> Gate<F, D>
for PoseidonGate<F, D, WIDTH>
where
[(); WIDTH - 1]:,
{
fn id(&self) -> String {
format!("{:?}<SPONGE_WIDTH={}>", self, SPONGE_WIDTH)
format!("{:?}<WIDTH={}>", self, WIDTH)
}
fn eval_unfiltered(&self, vars: EvaluationVars<F, D>) -> Vec<F::Extension> {
@ -94,69 +108,79 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for PoseidonGate<F, D> {
let swap = vars.local_wires[Self::WIRE_SWAP];
constraints.push(swap * (swap - F::Extension::ONE));
let mut state = Vec::with_capacity(SPONGE_WIDTH);
// Assert that each delta wire is set properly: `delta_i = swap * (rhs - lhs)`.
for i in 0..4 {
let a = vars.local_wires[i];
let b = vars.local_wires[i + 4];
state.push(a + swap * (b - a));
}
for i in 0..4 {
let a = vars.local_wires[i + 4];
let b = vars.local_wires[i];
state.push(a + swap * (b - a));
}
for i in 8..SPONGE_WIDTH {
state.push(vars.local_wires[i]);
let input_lhs = vars.local_wires[Self::wire_input(i)];
let input_rhs = vars.local_wires[Self::wire_input(i + 4)];
let delta_i = vars.local_wires[Self::wire_delta(i)];
constraints.push(swap * (input_rhs - input_lhs) - delta_i);
}
// Compute the possibly-swapped input layer.
let mut state = [F::Extension::ZERO; WIDTH];
for i in 0..4 {
let delta_i = vars.local_wires[Self::wire_delta(i)];
let input_lhs = Self::wire_input(i);
let input_rhs = Self::wire_input(i + 4);
state[i] = vars.local_wires[input_lhs] + delta_i;
state[i + 4] = vars.local_wires[input_rhs] - delta_i;
}
for i in 8..WIDTH {
state[i] = vars.local_wires[Self::wire_input(i)];
}
let mut state: [F::Extension; SPONGE_WIDTH] = state.try_into().unwrap();
let mut round_ctr = 0;
// First set of full rounds.
for r in 0..poseidon::HALF_N_FULL_ROUNDS {
<F as Poseidon>::constant_layer_field(&mut state, round_ctr);
for i in 0..SPONGE_WIDTH {
let sbox_in = vars.local_wires[Self::wire_full_sbox_0(r, i)];
constraints.push(state[i] - sbox_in);
state[i] = sbox_in;
<F as Poseidon<WIDTH>>::constant_layer_field(&mut state, round_ctr);
if r != 0 {
for i in 0..WIDTH {
let sbox_in = vars.local_wires[Self::wire_full_sbox_0(r, i)];
constraints.push(state[i] - sbox_in);
state[i] = sbox_in;
}
}
<F as Poseidon>::sbox_layer_field(&mut state);
state = <F as Poseidon>::mds_layer_field(&state);
<F as Poseidon<WIDTH>>::sbox_layer_field(&mut state);
state = <F as Poseidon<WIDTH>>::mds_layer_field(&state);
round_ctr += 1;
}
// Partial rounds.
<F as Poseidon>::partial_first_constant_layer(&mut state);
state = <F as Poseidon>::mds_partial_layer_init(&mut state);
<F as Poseidon<WIDTH>>::partial_first_constant_layer(&mut state);
state = <F as Poseidon<WIDTH>>::mds_partial_layer_init(&state);
for r in 0..(poseidon::N_PARTIAL_ROUNDS - 1) {
let sbox_in = vars.local_wires[Self::wire_partial_sbox(r)];
constraints.push(state[0] - sbox_in);
state[0] = <F as Poseidon>::sbox_monomial(sbox_in);
state[0] +=
F::Extension::from_canonical_u64(<F as Poseidon>::FAST_PARTIAL_ROUND_CONSTANTS[r]);
state = <F as Poseidon>::mds_partial_layer_fast_field(&state, r);
state[0] = <F as Poseidon<WIDTH>>::sbox_monomial(sbox_in);
state[0] += F::Extension::from_canonical_u64(
<F as Poseidon<WIDTH>>::FAST_PARTIAL_ROUND_CONSTANTS[r],
);
state = <F as Poseidon<WIDTH>>::mds_partial_layer_fast_field(&state, r);
}
let sbox_in = vars.local_wires[Self::wire_partial_sbox(poseidon::N_PARTIAL_ROUNDS - 1)];
constraints.push(state[0] - sbox_in);
state[0] = <F as Poseidon>::sbox_monomial(sbox_in);
state =
<F as Poseidon>::mds_partial_layer_fast_field(&state, poseidon::N_PARTIAL_ROUNDS - 1);
state[0] = <F as Poseidon<WIDTH>>::sbox_monomial(sbox_in);
state = <F as Poseidon<WIDTH>>::mds_partial_layer_fast_field(
&state,
poseidon::N_PARTIAL_ROUNDS - 1,
);
round_ctr += poseidon::N_PARTIAL_ROUNDS;
// Second set of full rounds.
for r in 0..poseidon::HALF_N_FULL_ROUNDS {
<F as Poseidon>::constant_layer_field(&mut state, round_ctr);
for i in 0..SPONGE_WIDTH {
<F as Poseidon<WIDTH>>::constant_layer_field(&mut state, round_ctr);
for i in 0..WIDTH {
let sbox_in = vars.local_wires[Self::wire_full_sbox_1(r, i)];
constraints.push(state[i] - sbox_in);
state[i] = sbox_in;
}
<F as Poseidon>::sbox_layer_field(&mut state);
state = <F as Poseidon>::mds_layer_field(&state);
<F as Poseidon<WIDTH>>::sbox_layer_field(&mut state);
state = <F as Poseidon<WIDTH>>::mds_layer_field(&state);
round_ctr += 1;
}
for i in 0..SPONGE_WIDTH {
for i in 0..WIDTH {
constraints.push(state[i] - vars.local_wires[Self::wire_output(i)]);
}
@ -170,67 +194,76 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for PoseidonGate<F, D> {
let swap = vars.local_wires[Self::WIRE_SWAP];
constraints.push(swap * swap.sub_one());
let mut state = Vec::with_capacity(SPONGE_WIDTH);
// Assert that each delta wire is set properly: `delta_i = swap * (rhs - lhs)`.
for i in 0..4 {
let a = vars.local_wires[i];
let b = vars.local_wires[i + 4];
state.push(a + swap * (b - a));
}
for i in 0..4 {
let a = vars.local_wires[i + 4];
let b = vars.local_wires[i];
state.push(a + swap * (b - a));
}
for i in 8..SPONGE_WIDTH {
state.push(vars.local_wires[i]);
let input_lhs = vars.local_wires[Self::wire_input(i)];
let input_rhs = vars.local_wires[Self::wire_input(i + 4)];
let delta_i = vars.local_wires[Self::wire_delta(i)];
constraints.push(swap * (input_rhs - input_lhs) - delta_i);
}
// Compute the possibly-swapped input layer.
let mut state = [F::ZERO; WIDTH];
for i in 0..4 {
let delta_i = vars.local_wires[Self::wire_delta(i)];
let input_lhs = Self::wire_input(i);
let input_rhs = Self::wire_input(i + 4);
state[i] = vars.local_wires[input_lhs] + delta_i;
state[i + 4] = vars.local_wires[input_rhs] - delta_i;
}
for i in 8..WIDTH {
state[i] = vars.local_wires[Self::wire_input(i)];
}
let mut state: [F; SPONGE_WIDTH] = state.try_into().unwrap();
let mut round_ctr = 0;
// First set of full rounds.
for r in 0..poseidon::HALF_N_FULL_ROUNDS {
<F as Poseidon>::constant_layer(&mut state, round_ctr);
for i in 0..SPONGE_WIDTH {
let sbox_in = vars.local_wires[Self::wire_full_sbox_0(r, i)];
constraints.push(state[i] - sbox_in);
state[i] = sbox_in;
<F as Poseidon<WIDTH>>::constant_layer(&mut state, round_ctr);
if r != 0 {
for i in 0..WIDTH {
let sbox_in = vars.local_wires[Self::wire_full_sbox_0(r, i)];
constraints.push(state[i] - sbox_in);
state[i] = sbox_in;
}
}
<F as Poseidon>::sbox_layer(&mut state);
state = <F as Poseidon>::mds_layer(&state);
<F as Poseidon<WIDTH>>::sbox_layer(&mut state);
state = <F as Poseidon<WIDTH>>::mds_layer(&state);
round_ctr += 1;
}
// Partial rounds.
<F as Poseidon>::partial_first_constant_layer(&mut state);
state = <F as Poseidon>::mds_partial_layer_init(&mut state);
<F as Poseidon<WIDTH>>::partial_first_constant_layer(&mut state);
state = <F as Poseidon<WIDTH>>::mds_partial_layer_init(&state);
for r in 0..(poseidon::N_PARTIAL_ROUNDS - 1) {
let sbox_in = vars.local_wires[Self::wire_partial_sbox(r)];
constraints.push(state[0] - sbox_in);
state[0] = <F as Poseidon>::sbox_monomial(sbox_in);
state[0] += F::from_canonical_u64(<F as Poseidon>::FAST_PARTIAL_ROUND_CONSTANTS[r]);
state = <F as Poseidon>::mds_partial_layer_fast(&state, r);
state[0] = <F as Poseidon<WIDTH>>::sbox_monomial(sbox_in);
state[0] +=
F::from_canonical_u64(<F as Poseidon<WIDTH>>::FAST_PARTIAL_ROUND_CONSTANTS[r]);
state = <F as Poseidon<WIDTH>>::mds_partial_layer_fast(&state, r);
}
let sbox_in = vars.local_wires[Self::wire_partial_sbox(poseidon::N_PARTIAL_ROUNDS - 1)];
constraints.push(state[0] - sbox_in);
state[0] = <F as Poseidon>::sbox_monomial(sbox_in);
state = <F as Poseidon>::mds_partial_layer_fast(&state, poseidon::N_PARTIAL_ROUNDS - 1);
state[0] = <F as Poseidon<WIDTH>>::sbox_monomial(sbox_in);
state =
<F as Poseidon<WIDTH>>::mds_partial_layer_fast(&state, poseidon::N_PARTIAL_ROUNDS - 1);
round_ctr += poseidon::N_PARTIAL_ROUNDS;
// Second set of full rounds.
for r in 0..poseidon::HALF_N_FULL_ROUNDS {
<F as Poseidon>::constant_layer(&mut state, round_ctr);
for i in 0..SPONGE_WIDTH {
<F as Poseidon<WIDTH>>::constant_layer(&mut state, round_ctr);
for i in 0..WIDTH {
let sbox_in = vars.local_wires[Self::wire_full_sbox_1(r, i)];
constraints.push(state[i] - sbox_in);
state[i] = sbox_in;
}
<F as Poseidon>::sbox_layer(&mut state);
state = <F as Poseidon>::mds_layer(&state);
<F as Poseidon<WIDTH>>::sbox_layer(&mut state);
state = <F as Poseidon<WIDTH>>::mds_layer(&state);
round_ctr += 1;
}
for i in 0..SPONGE_WIDTH {
for i in 0..WIDTH {
constraints.push(state[i] - vars.local_wires[Self::wire_output(i)]);
}
@ -244,7 +277,7 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for PoseidonGate<F, D> {
) -> Vec<ExtensionTarget<D>> {
// The naive method is more efficient if we have enough routed wires for PoseidonMdsGate.
let use_mds_gate =
builder.config.num_routed_wires >= PoseidonMdsGate::<F, D>::new().num_wires();
builder.config.num_routed_wires >= PoseidonMdsGate::<F, D, WIDTH>::new().num_wires();
let mut constraints = Vec::with_capacity(self.num_constraints());
@ -252,71 +285,73 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for PoseidonGate<F, D> {
let swap = vars.local_wires[Self::WIRE_SWAP];
constraints.push(builder.mul_sub_extension(swap, swap, swap));
let mut state = Vec::with_capacity(SPONGE_WIDTH);
// We need to compute both `if swap {b} else {a}` and `if swap {a} else {b}`.
// We will arithmetize them as
// swap (b - a) + a
// -swap (b - a) + b
// so that `b - a` can be used for both.
let mut state_first_4 = vec![];
let mut state_next_4 = vec![];
// Assert that each delta wire is set properly: `delta_i = swap * (rhs - lhs)`.
for i in 0..4 {
let a = vars.local_wires[i];
let b = vars.local_wires[i + 4];
let delta = builder.sub_extension(b, a);
state_first_4.push(builder.mul_add_extension(swap, delta, a));
state_next_4.push(builder.arithmetic_extension(F::NEG_ONE, F::ONE, swap, delta, b));
let input_lhs = vars.local_wires[Self::wire_input(i)];
let input_rhs = vars.local_wires[Self::wire_input(i + 4)];
let delta_i = vars.local_wires[Self::wire_delta(i)];
let diff = builder.sub_extension(input_rhs, input_lhs);
constraints.push(builder.mul_sub_extension(swap, diff, delta_i));
}
state.extend(state_first_4);
state.extend(state_next_4);
for i in 8..SPONGE_WIDTH {
state.push(vars.local_wires[i]);
// Compute the possibly-swapped input layer.
let mut state = [builder.zero_extension(); WIDTH];
for i in 0..4 {
let delta_i = vars.local_wires[Self::wire_delta(i)];
let input_lhs = vars.local_wires[Self::wire_input(i)];
let input_rhs = vars.local_wires[Self::wire_input(i + 4)];
state[i] = builder.add_extension(input_lhs, delta_i);
state[i + 4] = builder.sub_extension(input_rhs, delta_i);
}
for i in 8..WIDTH {
state[i] = vars.local_wires[Self::wire_input(i)];
}
let mut state: [ExtensionTarget<D>; SPONGE_WIDTH] = state.try_into().unwrap();
let mut round_ctr = 0;
// First set of full rounds.
for r in 0..poseidon::HALF_N_FULL_ROUNDS {
<F as Poseidon>::constant_layer_recursive(builder, &mut state, round_ctr);
for i in 0..SPONGE_WIDTH {
let sbox_in = vars.local_wires[Self::wire_full_sbox_0(r, i)];
constraints.push(builder.sub_extension(state[i], sbox_in));
state[i] = sbox_in;
<F as Poseidon<WIDTH>>::constant_layer_recursive(builder, &mut state, round_ctr);
if r != 0 {
for i in 0..WIDTH {
let sbox_in = vars.local_wires[Self::wire_full_sbox_0(r, i)];
constraints.push(builder.sub_extension(state[i], sbox_in));
state[i] = sbox_in;
}
}
<F as Poseidon>::sbox_layer_recursive(builder, &mut state);
state = <F as Poseidon>::mds_layer_recursive(builder, &state);
<F as Poseidon<WIDTH>>::sbox_layer_recursive(builder, &mut state);
state = <F as Poseidon<WIDTH>>::mds_layer_recursive(builder, &state);
round_ctr += 1;
}
// Partial rounds.
if use_mds_gate {
for r in 0..poseidon::N_PARTIAL_ROUNDS {
<F as Poseidon>::constant_layer_recursive(builder, &mut state, round_ctr);
<F as Poseidon<WIDTH>>::constant_layer_recursive(builder, &mut state, round_ctr);
let sbox_in = vars.local_wires[Self::wire_partial_sbox(r)];
constraints.push(builder.sub_extension(state[0], sbox_in));
state[0] = <F as Poseidon>::sbox_monomial_recursive(builder, sbox_in);
state = <F as Poseidon>::mds_layer_recursive(builder, &state);
state[0] = <F as Poseidon<WIDTH>>::sbox_monomial_recursive(builder, sbox_in);
state = <F as Poseidon<WIDTH>>::mds_layer_recursive(builder, &state);
round_ctr += 1;
}
} else {
<F as Poseidon>::partial_first_constant_layer_recursive(builder, &mut state);
state = <F as Poseidon>::mds_partial_layer_init_recursive(builder, &mut state);
<F as Poseidon<WIDTH>>::partial_first_constant_layer_recursive(builder, &mut state);
state = <F as Poseidon<WIDTH>>::mds_partial_layer_init_recursive(builder, &state);
for r in 0..(poseidon::N_PARTIAL_ROUNDS - 1) {
let sbox_in = vars.local_wires[Self::wire_partial_sbox(r)];
constraints.push(builder.sub_extension(state[0], sbox_in));
state[0] = <F as Poseidon>::sbox_monomial_recursive(builder, sbox_in);
state[0] = builder.add_const_extension(
state[0],
F::from_canonical_u64(<F as Poseidon>::FAST_PARTIAL_ROUND_CONSTANTS[r]),
);
state = <F as Poseidon>::mds_partial_layer_fast_recursive(builder, &state, r);
state[0] = <F as Poseidon<WIDTH>>::sbox_monomial_recursive(builder, sbox_in);
let c = <F as Poseidon<WIDTH>>::FAST_PARTIAL_ROUND_CONSTANTS[r];
let c = F::Extension::from_canonical_u64(c);
let c = builder.constant_extension(c);
state[0] = builder.add_extension(state[0], c);
state =
<F as Poseidon<WIDTH>>::mds_partial_layer_fast_recursive(builder, &state, r);
}
let sbox_in = vars.local_wires[Self::wire_partial_sbox(poseidon::N_PARTIAL_ROUNDS - 1)];
constraints.push(builder.sub_extension(state[0], sbox_in));
state[0] = <F as Poseidon>::sbox_monomial_recursive(builder, sbox_in);
state = <F as Poseidon>::mds_partial_layer_fast_recursive(
state[0] = <F as Poseidon<WIDTH>>::sbox_monomial_recursive(builder, sbox_in);
state = <F as Poseidon<WIDTH>>::mds_partial_layer_fast_recursive(
builder,
&state,
poseidon::N_PARTIAL_ROUNDS - 1,
@ -326,18 +361,18 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for PoseidonGate<F, D> {
// Second set of full rounds.
for r in 0..poseidon::HALF_N_FULL_ROUNDS {
<F as Poseidon>::constant_layer_recursive(builder, &mut state, round_ctr);
for i in 0..SPONGE_WIDTH {
<F as Poseidon<WIDTH>>::constant_layer_recursive(builder, &mut state, round_ctr);
for i in 0..WIDTH {
let sbox_in = vars.local_wires[Self::wire_full_sbox_1(r, i)];
constraints.push(builder.sub_extension(state[i], sbox_in));
state[i] = sbox_in;
}
<F as Poseidon>::sbox_layer_recursive(builder, &mut state);
state = <F as Poseidon>::mds_layer_recursive(builder, &state);
<F as Poseidon<WIDTH>>::sbox_layer_recursive(builder, &mut state);
state = <F as Poseidon<WIDTH>>::mds_layer_recursive(builder, &state);
round_ctr += 1;
}
for i in 0..SPONGE_WIDTH {
for i in 0..WIDTH {
constraints
.push(builder.sub_extension(state[i], vars.local_wires[Self::wire_output(i)]));
}
@ -350,7 +385,7 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for PoseidonGate<F, D> {
gate_index: usize,
_local_constants: &[F],
) -> Vec<Box<dyn WitnessGenerator<F>>> {
let gen = PoseidonGenerator::<F, D> {
let gen = PoseidonGenerator::<F, D, WIDTH> {
gate_index,
_phantom: PhantomData,
};
@ -370,23 +405,31 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for PoseidonGate<F, D> {
}
fn num_constraints(&self) -> usize {
SPONGE_WIDTH * poseidon::N_FULL_ROUNDS_TOTAL + poseidon::N_PARTIAL_ROUNDS + SPONGE_WIDTH + 1
WIDTH * (poseidon::N_FULL_ROUNDS_TOTAL - 1) + poseidon::N_PARTIAL_ROUNDS + WIDTH + 1 + 4
}
}
#[derive(Debug)]
struct PoseidonGenerator<F: Extendable<D> + Poseidon, const D: usize> {
struct PoseidonGenerator<
F: RichField + Extendable<D> + Poseidon<WIDTH>,
const D: usize,
const WIDTH: usize,
> where
[(); WIDTH - 1]:,
{
gate_index: usize,
_phantom: PhantomData<F>,
}
impl<F: RichField + Extendable<D> + Poseidon, const D: usize> SimpleGenerator<F>
for PoseidonGenerator<F, D>
impl<F: RichField + Extendable<D> + Poseidon<WIDTH>, const D: usize, const WIDTH: usize>
SimpleGenerator<F> for PoseidonGenerator<F, D, WIDTH>
where
[(); WIDTH - 1]:,
{
fn dependencies(&self) -> Vec<Target> {
(0..SPONGE_WIDTH)
.map(|i| PoseidonGate::<F, D>::wire_input(i))
.chain(Some(PoseidonGate::<F, D>::WIRE_SWAP))
(0..WIDTH)
.map(|i| PoseidonGate::<F, D, WIDTH>::wire_input(i))
.chain(Some(PoseidonGate::<F, D, WIDTH>::WIRE_SWAP))
.map(|input| Target::wire(self.gate_index, input))
.collect()
}
@ -397,87 +440,94 @@ impl<F: RichField + Extendable<D> + Poseidon, const D: usize> SimpleGenerator<F>
input,
};
let mut state = (0..SPONGE_WIDTH)
.map(|i| {
witness.get_wire(Wire {
gate: self.gate_index,
input: PoseidonGate::<F, D>::wire_input(i),
})
})
let mut state = (0..WIDTH)
.map(|i| witness.get_wire(local_wire(PoseidonGate::<F, D, WIDTH>::wire_input(i))))
.collect::<Vec<_>>();
let swap_value = witness.get_wire(Wire {
gate: self.gate_index,
input: PoseidonGate::<F, D>::WIRE_SWAP,
});
let swap_value = witness.get_wire(local_wire(PoseidonGate::<F, D, WIDTH>::WIRE_SWAP));
debug_assert!(swap_value == F::ZERO || swap_value == F::ONE);
for i in 0..4 {
let delta_i = swap_value * (state[i + 4] - state[i]);
out_buffer.set_wire(
local_wire(PoseidonGate::<F, D, WIDTH>::wire_delta(i)),
delta_i,
);
}
if swap_value == F::ONE {
for i in 0..4 {
state.swap(i, 4 + i);
}
}
let mut state: [F; SPONGE_WIDTH] = state.try_into().unwrap();
let mut state: [F; WIDTH] = state.try_into().unwrap();
let mut round_ctr = 0;
for r in 0..poseidon::HALF_N_FULL_ROUNDS {
<F as Poseidon>::constant_layer_field(&mut state, round_ctr);
for i in 0..SPONGE_WIDTH {
out_buffer.set_wire(
local_wire(PoseidonGate::<F, D>::wire_full_sbox_0(r, i)),
state[i],
);
<F as Poseidon<WIDTH>>::constant_layer_field(&mut state, round_ctr);
if r != 0 {
for i in 0..WIDTH {
out_buffer.set_wire(
local_wire(PoseidonGate::<F, D, WIDTH>::wire_full_sbox_0(r, i)),
state[i],
);
}
}
<F as Poseidon>::sbox_layer_field(&mut state);
state = <F as Poseidon>::mds_layer_field(&state);
<F as Poseidon<WIDTH>>::sbox_layer_field(&mut state);
state = <F as Poseidon<WIDTH>>::mds_layer_field(&state);
round_ctr += 1;
}
<F as Poseidon>::partial_first_constant_layer(&mut state);
state = <F as Poseidon>::mds_partial_layer_init(&mut state);
<F as Poseidon<WIDTH>>::partial_first_constant_layer(&mut state);
state = <F as Poseidon<WIDTH>>::mds_partial_layer_init(&state);
for r in 0..(poseidon::N_PARTIAL_ROUNDS - 1) {
out_buffer.set_wire(
local_wire(PoseidonGate::<F, D>::wire_partial_sbox(r)),
local_wire(PoseidonGate::<F, D, WIDTH>::wire_partial_sbox(r)),
state[0],
);
state[0] = <F as Poseidon>::sbox_monomial(state[0]);
state[0] += F::from_canonical_u64(<F as Poseidon>::FAST_PARTIAL_ROUND_CONSTANTS[r]);
state = <F as Poseidon>::mds_partial_layer_fast_field(&state, r);
state[0] = <F as Poseidon<WIDTH>>::sbox_monomial(state[0]);
state[0] +=
F::from_canonical_u64(<F as Poseidon<WIDTH>>::FAST_PARTIAL_ROUND_CONSTANTS[r]);
state = <F as Poseidon<WIDTH>>::mds_partial_layer_fast_field(&state, r);
}
out_buffer.set_wire(
local_wire(PoseidonGate::<F, D>::wire_partial_sbox(
local_wire(PoseidonGate::<F, D, WIDTH>::wire_partial_sbox(
poseidon::N_PARTIAL_ROUNDS - 1,
)),
state[0],
);
state[0] = <F as Poseidon>::sbox_monomial(state[0]);
state =
<F as Poseidon>::mds_partial_layer_fast_field(&state, poseidon::N_PARTIAL_ROUNDS - 1);
state[0] = <F as Poseidon<WIDTH>>::sbox_monomial(state[0]);
state = <F as Poseidon<WIDTH>>::mds_partial_layer_fast_field(
&state,
poseidon::N_PARTIAL_ROUNDS - 1,
);
round_ctr += poseidon::N_PARTIAL_ROUNDS;
for r in 0..poseidon::HALF_N_FULL_ROUNDS {
<F as Poseidon>::constant_layer_field(&mut state, round_ctr);
for i in 0..SPONGE_WIDTH {
<F as Poseidon<WIDTH>>::constant_layer_field(&mut state, round_ctr);
for i in 0..WIDTH {
out_buffer.set_wire(
local_wire(PoseidonGate::<F, D>::wire_full_sbox_1(r, i)),
local_wire(PoseidonGate::<F, D, WIDTH>::wire_full_sbox_1(r, i)),
state[i],
);
}
<F as Poseidon>::sbox_layer_field(&mut state);
state = <F as Poseidon>::mds_layer_field(&state);
<F as Poseidon<WIDTH>>::sbox_layer_field(&mut state);
state = <F as Poseidon<WIDTH>>::mds_layer_field(&state);
round_ctr += 1;
}
for i in 0..SPONGE_WIDTH {
out_buffer.set_wire(local_wire(PoseidonGate::<F, D>::wire_output(i)), state[i]);
for i in 0..WIDTH {
out_buffer.set_wire(
local_wire(PoseidonGate::<F, D, WIDTH>::wire_output(i)),
state[i],
);
}
}
}
#[cfg(test)]
mod tests {
use std::convert::TryInto;
use anyhow::Result;
use crate::field::field_types::Field;
@ -493,6 +543,29 @@ mod tests {
use crate::plonk::circuit_data::CircuitConfig;
use crate::plonk::config::{GenericConfig, PoseidonGoldilocksConfig};
/// Pins the wire layout of the gate for `WIDTH = 12`, so that accidental changes to the
/// index arithmetic are caught.
#[test]
fn wire_indices() {
type F = GoldilocksField;
const WIDTH: usize = 12;
type Gate = PoseidonGate<F, 4, WIDTH>;
// Inputs, then outputs, then the swap flag.
assert_eq!(Gate::wire_input(0), 0);
assert_eq!(Gate::wire_input(11), 11);
assert_eq!(Gate::wire_output(0), 12);
assert_eq!(Gate::wire_output(11), 23);
assert_eq!(Gate::WIRE_SWAP, 24);
// The four delta wires.
assert_eq!(Gate::wire_delta(0), 25);
assert_eq!(Gate::wire_delta(3), 28);
// First set of full rounds; numbering starts at round 1 because the first round's S-box
// inputs are not stored as wires.
assert_eq!(Gate::wire_full_sbox_0(1, 0), 29);
assert_eq!(Gate::wire_full_sbox_0(3, 0), 53);
assert_eq!(Gate::wire_full_sbox_0(3, 11), 64);
// Partial rounds: one S-box input wire per round.
assert_eq!(Gate::wire_partial_sbox(0), 65);
assert_eq!(Gate::wire_partial_sbox(21), 86);
// Second set of full rounds.
assert_eq!(Gate::wire_full_sbox_1(0, 0), 87);
assert_eq!(Gate::wire_full_sbox_1(3, 0), 123);
assert_eq!(Gate::wire_full_sbox_1(3, 11), 134);
}
#[test]
fn generated_output() {
const D: usize = 2;

View File

@ -1,4 +1,3 @@
use std::convert::TryInto;
use std::marker::PhantomData;
use std::ops::Range;
@ -6,9 +5,8 @@ use crate::field::extension_field::algebra::ExtensionAlgebra;
use crate::field::extension_field::target::{ExtensionAlgebraTarget, ExtensionTarget};
use crate::field::extension_field::Extendable;
use crate::field::extension_field::FieldExtension;
use crate::field::field_types::Field;
use crate::field::field_types::{Field, RichField};
use crate::gates::gate::Gate;
use crate::hash::hashing::SPONGE_WIDTH;
use crate::hash::poseidon::Poseidon;
use crate::iop::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator};
use crate::iop::target::Target;
@ -17,11 +15,21 @@ use crate::plonk::circuit_builder::CircuitBuilder;
use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase};
#[derive(Debug)]
pub struct PoseidonMdsGate<F: Extendable<D> + Poseidon, const D: usize> {
pub struct PoseidonMdsGate<
F: RichField + Extendable<D> + Poseidon<WIDTH>,
const D: usize,
const WIDTH: usize,
> where
[(); WIDTH - 1]:,
{
_phantom: PhantomData<F>,
}
impl<F: Extendable<D> + Poseidon, const D: usize> PoseidonMdsGate<F, D> {
impl<F: RichField + Extendable<D> + Poseidon<WIDTH>, const D: usize, const WIDTH: usize>
PoseidonMdsGate<F, D, WIDTH>
where
[(); WIDTH - 1]:,
{
pub fn new() -> Self {
PoseidonMdsGate {
_phantom: PhantomData,
@ -29,13 +37,13 @@ impl<F: Extendable<D> + Poseidon, const D: usize> PoseidonMdsGate<F, D> {
}
pub fn wires_input(i: usize) -> Range<usize> {
assert!(i < SPONGE_WIDTH);
assert!(i < WIDTH);
i * D..(i + 1) * D
}
pub fn wires_output(i: usize) -> Range<usize> {
assert!(i < SPONGE_WIDTH);
(SPONGE_WIDTH + i) * D..(SPONGE_WIDTH + i + 1) * D
assert!(i < WIDTH);
(WIDTH + i) * D..(WIDTH + i + 1) * D
}
// Following are methods analogous to ones in `Poseidon`, but for extension algebras.
@ -43,14 +51,15 @@ impl<F: Extendable<D> + Poseidon, const D: usize> PoseidonMdsGate<F, D> {
/// Same as `mds_row_shf` for an extension algebra of `F`.
fn mds_row_shf_algebra(
r: usize,
v: &[ExtensionAlgebra<F::Extension, D>; SPONGE_WIDTH],
v: &[ExtensionAlgebra<F::Extension, D>; WIDTH],
) -> ExtensionAlgebra<F::Extension, D> {
debug_assert!(r < SPONGE_WIDTH);
debug_assert!(r < WIDTH);
let mut res = ExtensionAlgebra::ZERO;
for i in 0..SPONGE_WIDTH {
let coeff = F::Extension::from_canonical_u64(1 << <F as Poseidon>::MDS_MATRIX_EXPS[i]);
res += v[(i + r) % SPONGE_WIDTH].scalar_mul(coeff);
for i in 0..WIDTH {
let coeff =
F::Extension::from_canonical_u64(1 << <F as Poseidon<WIDTH>>::MDS_MATRIX_EXPS[i]);
res += v[(i + r) % WIDTH].scalar_mul(coeff);
}
res
@ -60,16 +69,16 @@ impl<F: Extendable<D> + Poseidon, const D: usize> PoseidonMdsGate<F, D> {
fn mds_row_shf_algebra_recursive(
builder: &mut CircuitBuilder<F, D>,
r: usize,
v: &[ExtensionAlgebraTarget<D>; SPONGE_WIDTH],
v: &[ExtensionAlgebraTarget<D>; WIDTH],
) -> ExtensionAlgebraTarget<D> {
debug_assert!(r < SPONGE_WIDTH);
debug_assert!(r < WIDTH);
let mut res = builder.zero_ext_algebra();
for i in 0..SPONGE_WIDTH {
for i in 0..WIDTH {
let coeff = builder.constant_extension(F::Extension::from_canonical_u64(
1 << <F as Poseidon>::MDS_MATRIX_EXPS[i],
1 << <F as Poseidon<WIDTH>>::MDS_MATRIX_EXPS[i],
));
res = builder.scalar_mul_add_ext_algebra(coeff, v[(i + r) % SPONGE_WIDTH], res);
res = builder.scalar_mul_add_ext_algebra(coeff, v[(i + r) % WIDTH], res);
}
res
@ -77,11 +86,11 @@ impl<F: Extendable<D> + Poseidon, const D: usize> PoseidonMdsGate<F, D> {
/// Same as `mds_layer` for an extension algebra of `F`.
fn mds_layer_algebra(
state: &[ExtensionAlgebra<F::Extension, D>; SPONGE_WIDTH],
) -> [ExtensionAlgebra<F::Extension, D>; SPONGE_WIDTH] {
let mut result = [ExtensionAlgebra::ZERO; SPONGE_WIDTH];
state: &[ExtensionAlgebra<F::Extension, D>; WIDTH],
) -> [ExtensionAlgebra<F::Extension, D>; WIDTH] {
let mut result = [ExtensionAlgebra::ZERO; WIDTH];
for r in 0..SPONGE_WIDTH {
for r in 0..WIDTH {
result[r] = Self::mds_row_shf_algebra(r, state);
}
@ -91,11 +100,11 @@ impl<F: Extendable<D> + Poseidon, const D: usize> PoseidonMdsGate<F, D> {
/// Same as `mds_layer_recursive` for an extension algebra of `F`.
fn mds_layer_algebra_recursive(
builder: &mut CircuitBuilder<F, D>,
state: &[ExtensionAlgebraTarget<D>; SPONGE_WIDTH],
) -> [ExtensionAlgebraTarget<D>; SPONGE_WIDTH] {
let mut result = [builder.zero_ext_algebra(); SPONGE_WIDTH];
state: &[ExtensionAlgebraTarget<D>; WIDTH],
) -> [ExtensionAlgebraTarget<D>; WIDTH] {
let mut result = [builder.zero_ext_algebra(); WIDTH];
for r in 0..SPONGE_WIDTH {
for r in 0..WIDTH {
result[r] = Self::mds_row_shf_algebra_recursive(builder, r, state);
}
@ -103,13 +112,17 @@ impl<F: Extendable<D> + Poseidon, const D: usize> PoseidonMdsGate<F, D> {
}
}
impl<F: Extendable<D> + Poseidon, const D: usize> Gate<F, D> for PoseidonMdsGate<F, D> {
impl<F: RichField + Extendable<D> + Poseidon<WIDTH>, const D: usize, const WIDTH: usize> Gate<F, D>
for PoseidonMdsGate<F, D, WIDTH>
where
[(); WIDTH - 1]:,
{
fn id(&self) -> String {
format!("{:?}<WIDTH={}>", self, SPONGE_WIDTH)
format!("{:?}<WIDTH={}>", self, WIDTH)
}
fn eval_unfiltered(&self, vars: EvaluationVars<F, D>) -> Vec<F::Extension> {
let inputs: [_; SPONGE_WIDTH] = (0..SPONGE_WIDTH)
let inputs: [_; WIDTH] = (0..WIDTH)
.map(|i| vars.get_local_ext_algebra(Self::wires_input(i)))
.collect::<Vec<_>>()
.try_into()
@ -117,7 +130,7 @@ impl<F: Extendable<D> + Poseidon, const D: usize> Gate<F, D> for PoseidonMdsGate
let computed_outputs = Self::mds_layer_algebra(&inputs);
(0..SPONGE_WIDTH)
(0..WIDTH)
.map(|i| vars.get_local_ext_algebra(Self::wires_output(i)))
.zip(computed_outputs)
.flat_map(|(out, computed_out)| (out - computed_out).to_basefield_array())
@ -125,7 +138,7 @@ impl<F: Extendable<D> + Poseidon, const D: usize> Gate<F, D> for PoseidonMdsGate
}
fn eval_unfiltered_base(&self, vars: EvaluationVarsBase<F>) -> Vec<F> {
let inputs: [_; SPONGE_WIDTH] = (0..SPONGE_WIDTH)
let inputs: [_; WIDTH] = (0..WIDTH)
.map(|i| vars.get_local_ext(Self::wires_input(i)))
.collect::<Vec<_>>()
.try_into()
@ -133,7 +146,7 @@ impl<F: Extendable<D> + Poseidon, const D: usize> Gate<F, D> for PoseidonMdsGate
let computed_outputs = F::mds_layer_field(&inputs);
(0..SPONGE_WIDTH)
(0..WIDTH)
.map(|i| vars.get_local_ext(Self::wires_output(i)))
.zip(computed_outputs)
.flat_map(|(out, computed_out)| (out - computed_out).to_basefield_array())
@ -145,7 +158,7 @@ impl<F: Extendable<D> + Poseidon, const D: usize> Gate<F, D> for PoseidonMdsGate
builder: &mut CircuitBuilder<F, D>,
vars: EvaluationTargets<D>,
) -> Vec<ExtensionTarget<D>> {
let inputs: [_; SPONGE_WIDTH] = (0..SPONGE_WIDTH)
let inputs: [_; WIDTH] = (0..WIDTH)
.map(|i| vars.get_local_ext_algebra(Self::wires_input(i)))
.collect::<Vec<_>>()
.try_into()
@ -153,7 +166,7 @@ impl<F: Extendable<D> + Poseidon, const D: usize> Gate<F, D> for PoseidonMdsGate
let computed_outputs = Self::mds_layer_algebra_recursive(builder, &inputs);
(0..SPONGE_WIDTH)
(0..WIDTH)
.map(|i| vars.get_local_ext_algebra(Self::wires_output(i)))
.zip(computed_outputs)
.flat_map(|(out, computed_out)| {
@ -169,12 +182,12 @@ impl<F: Extendable<D> + Poseidon, const D: usize> Gate<F, D> for PoseidonMdsGate
gate_index: usize,
_local_constants: &[F],
) -> Vec<Box<dyn WitnessGenerator<F>>> {
let gen = PoseidonMdsGenerator::<D> { gate_index };
let gen = PoseidonMdsGenerator::<D, WIDTH> { gate_index };
vec![Box::new(gen.adapter())]
}
fn num_wires(&self) -> usize {
2 * D * SPONGE_WIDTH
2 * D * WIDTH
}
fn num_constants(&self) -> usize {
@ -186,20 +199,30 @@ impl<F: Extendable<D> + Poseidon, const D: usize> Gate<F, D> for PoseidonMdsGate
}
fn num_constraints(&self) -> usize {
SPONGE_WIDTH * D
WIDTH * D
}
}
#[derive(Clone, Debug)]
struct PoseidonMdsGenerator<const D: usize> {
struct PoseidonMdsGenerator<const D: usize, const WIDTH: usize>
where
[(); WIDTH - 1]:,
{
gate_index: usize,
}
impl<F: Extendable<D> + Poseidon, const D: usize> SimpleGenerator<F> for PoseidonMdsGenerator<D> {
impl<F: RichField + Extendable<D> + Poseidon<WIDTH>, const D: usize, const WIDTH: usize>
SimpleGenerator<F> for PoseidonMdsGenerator<D, WIDTH>
where
[(); WIDTH - 1]:,
{
fn dependencies(&self) -> Vec<Target> {
(0..SPONGE_WIDTH)
(0..WIDTH)
.flat_map(|i| {
Target::wires_from_range(self.gate_index, PoseidonMdsGate::<F, D>::wires_input(i))
Target::wires_from_range(
self.gate_index,
PoseidonMdsGate::<F, D, WIDTH>::wires_input(i),
)
})
.collect()
}
@ -210,8 +233,8 @@ impl<F: Extendable<D> + Poseidon, const D: usize> SimpleGenerator<F> for Poseido
let get_local_ext =
|wire_range| witness.get_extension_target(get_local_get_target(wire_range));
let inputs: [_; SPONGE_WIDTH] = (0..SPONGE_WIDTH)
.map(|i| get_local_ext(PoseidonMdsGate::<F, D>::wires_input(i)))
let inputs: [_; WIDTH] = (0..WIDTH)
.map(|i| get_local_ext(PoseidonMdsGate::<F, D, WIDTH>::wires_input(i)))
.collect::<Vec<_>>()
.try_into()
.unwrap();
@ -220,7 +243,7 @@ impl<F: Extendable<D> + Poseidon, const D: usize> SimpleGenerator<F> for Poseido
for (i, &out) in outputs.iter().enumerate() {
out_buffer.set_extension_target(
get_local_get_target(PoseidonMdsGate::<F, D>::wires_output(i)),
get_local_get_target(PoseidonMdsGate::<F, D, WIDTH>::wires_output(i)),
out,
);
}
@ -232,21 +255,19 @@ mod tests {
use crate::field::goldilocks_field::GoldilocksField;
use crate::gates::gate_testing::{test_eval_fns, test_low_degree};
use crate::gates::poseidon_mds::PoseidonMdsGate;
use crate::plonk::config::{GenericConfig, PoseidonGoldilocksConfig};
use crate::hash::hashing::SPONGE_WIDTH;
#[test]
fn low_degree() {
type F = GoldilocksField;
let gate = PoseidonMdsGate::<F, 4>::new();
let gate = PoseidonMdsGate::<F, 4, SPONGE_WIDTH>::new();
test_low_degree(gate)
}
#[test]
fn eval_fns() -> anyhow::Result<()> {
const D: usize = 2;
type C = PoseidonGoldilocksConfig;
type F = <C as GenericConfig<D>>::F;
let gate = PoseidonMdsGate::<F, D>::new();
test_eval_fns::<F, C, _, D>(gate)
type F = GoldilocksField;
let gate = PoseidonMdsGate::<F, 4, SPONGE_WIDTH>::new();
test_eval_fns(gate)
}
}

View File

@ -1,5 +1,7 @@
use std::marker::PhantomData;
use itertools::Itertools;
use crate::field::extension_field::target::ExtensionTarget;
use crate::field::extension_field::Extendable;
use crate::field::field_types::Field;
@ -14,76 +16,65 @@ use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase};
/// A gate for checking that a particular element of a list matches a given value.
#[derive(Copy, Clone, Debug)]
pub(crate) struct RandomAccessGate<F: Extendable<D>, const D: usize> {
pub vec_size: usize,
pub(crate) struct RandomAccessGate<F: Extendable<D>, const D: usize> {
pub bits: usize,
pub num_copies: usize,
_phantom: PhantomData<F>,
}
impl<F: Extendable<D>, const D: usize> RandomAccessGate<F, D> {
pub fn new(num_copies: usize, vec_size: usize) -> Self {
impl<F: Extendable<D>, const D: usize> RandomAccessGate<F, D> {
fn new(num_copies: usize, bits: usize) -> Self {
Self {
vec_size,
bits,
num_copies,
_phantom: PhantomData,
}
}
pub fn new_from_config(config: &CircuitConfig, vec_size: usize) -> Self {
let num_copies = Self::max_num_copies(config.num_routed_wires, config.num_wires, vec_size);
Self::new(num_copies, vec_size)
pub fn new_from_config(config: &CircuitConfig, bits: usize) -> Self {
let vec_size = 1 << bits;
// Need `(2 + vec_size) * num_copies` routed wires
let max_copies = (config.num_routed_wires / (2 + vec_size)).min(
// Need `(2 + vec_size + bits) * num_copies` wires
config.num_wires / (2 + vec_size + bits),
);
Self::new(max_copies, bits)
}
pub fn max_num_copies(num_routed_wires: usize, num_wires: usize, vec_size: usize) -> usize {
// Need `(2 + vec_size) * num_copies` routed wires
(num_routed_wires / (2 + vec_size)).min(
// Need `(2 + 3*vec_size) * num_copies` wires
num_wires / (2 + 3 * vec_size),
)
fn vec_size(&self) -> usize {
1 << self.bits
}
pub fn wire_access_index(&self, copy: usize) -> usize {
debug_assert!(copy < self.num_copies);
(2 + self.vec_size) * copy
(2 + self.vec_size()) * copy
}
pub fn wire_claimed_element(&self, copy: usize) -> usize {
debug_assert!(copy < self.num_copies);
(2 + self.vec_size) * copy + 1
(2 + self.vec_size()) * copy + 1
}
pub fn wire_list_item(&self, i: usize, copy: usize) -> usize {
debug_assert!(i < self.vec_size);
debug_assert!(i < self.vec_size());
debug_assert!(copy < self.num_copies);
(2 + self.vec_size) * copy + 2 + i
(2 + self.vec_size()) * copy + 2 + i
}
fn start_of_intermediate_wires(&self) -> usize {
(2 + self.vec_size) * self.num_copies
(2 + self.vec_size()) * self.num_copies
}
pub(crate) fn num_routed_wires(&self) -> usize {
self.start_of_intermediate_wires()
}
/// An intermediate wire for a dummy variable used to show equality.
/// The prover sets this to 1/(x-y) if x != y, or to an arbitrary value if
/// x == y.
pub fn wire_equality_dummy_for_index(&self, i: usize, copy: usize) -> usize {
debug_assert!(i < self.vec_size);
/// An intermediate wire where the prover gives the (purported) binary decomposition of the
/// index.
pub fn wire_bit(&self, i: usize, copy: usize) -> usize {
debug_assert!(i < self.bits);
debug_assert!(copy < self.num_copies);
self.start_of_intermediate_wires() + copy * self.vec_size + i
}
/// An intermediate wire for the "index_matches" variable (1 if the current index is the index at
/// which to compare, 0 otherwise).
pub fn wire_index_matches_for_index(&self, i: usize, copy: usize) -> usize {
debug_assert!(i < self.vec_size);
debug_assert!(copy < self.num_copies);
self.start_of_intermediate_wires()
+ self.vec_size * self.num_copies
+ self.vec_size * copy
+ i
self.start_of_intermediate_wires() + copy * self.bits + i
}
}
@ -97,23 +88,38 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for RandomAccessGate<F, D> {
for copy in 0..self.num_copies {
let access_index = vars.local_wires[self.wire_access_index(copy)];
let list_items = (0..self.vec_size)
let mut list_items = (0..self.vec_size())
.map(|i| vars.local_wires[self.wire_list_item(i, copy)])
.collect::<Vec<_>>();
let claimed_element = vars.local_wires[self.wire_claimed_element(copy)];
let bits = (0..self.bits)
.map(|i| vars.local_wires[self.wire_bit(i, copy)])
.collect::<Vec<_>>();
for i in 0..self.vec_size {
let cur_index = F::Extension::from_canonical_usize(i);
let difference = cur_index - access_index;
let equality_dummy = vars.local_wires[self.wire_equality_dummy_for_index(i, copy)];
let index_matches = vars.local_wires[self.wire_index_matches_for_index(i, copy)];
// The two index equality constraints.
constraints.push(difference * equality_dummy - (F::Extension::ONE - index_matches));
constraints.push(index_matches * difference);
// Value equality constraint.
constraints.push((list_items[i] - claimed_element) * index_matches);
// Assert that each bit wire value is indeed boolean.
for &b in &bits {
constraints.push(b * (b - F::Extension::ONE));
}
// Assert that the binary decomposition was correct.
let reconstructed_index = bits
.iter()
.rev()
.fold(F::Extension::ZERO, |acc, &b| acc.double() + b);
constraints.push(reconstructed_index - access_index);
// Repeatedly fold the list, selecting the left or right item from each pair based on
// the corresponding bit.
for b in bits {
list_items = list_items
.iter()
.tuples()
.map(|(&x, &y)| x + b * (y - x))
.collect()
}
debug_assert_eq!(list_items.len(), 1);
constraints.push(list_items[0] - claimed_element);
}
constraints
@ -124,23 +130,35 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for RandomAccessGate<F, D> {
for copy in 0..self.num_copies {
let access_index = vars.local_wires[self.wire_access_index(copy)];
let list_items = (0..self.vec_size)
let mut list_items = (0..self.vec_size())
.map(|i| vars.local_wires[self.wire_list_item(i, copy)])
.collect::<Vec<_>>();
let claimed_element = vars.local_wires[self.wire_claimed_element(copy)];
let bits = (0..self.bits)
.map(|i| vars.local_wires[self.wire_bit(i, copy)])
.collect::<Vec<_>>();
for i in 0..self.vec_size {
let cur_index = F::from_canonical_usize(i);
let difference = cur_index - access_index;
let equality_dummy = vars.local_wires[self.wire_equality_dummy_for_index(i, copy)];
let index_matches = vars.local_wires[self.wire_index_matches_for_index(i, copy)];
// The two index equality constraints.
constraints.push(difference * equality_dummy - (F::ONE - index_matches));
constraints.push(index_matches * difference);
// Value equality constraint.
constraints.push((list_items[i] - claimed_element) * index_matches);
// Assert that each bit wire value is indeed boolean.
for &b in &bits {
constraints.push(b * (b - F::ONE));
}
// Assert that the binary decomposition was correct.
let reconstructed_index = bits.iter().rev().fold(F::ZERO, |acc, &b| acc.double() + b);
constraints.push(reconstructed_index - access_index);
// Repeatedly fold the list, selecting the left or right item from each pair based on
// the corresponding bit.
for b in bits {
list_items = list_items
.iter()
.tuples()
.map(|(&x, &y)| x + b * (y - x))
.collect()
}
debug_assert_eq!(list_items.len(), 1);
constraints.push(list_items[0] - claimed_element);
}
constraints
@ -151,36 +169,44 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for RandomAccessGate<F, D> {
builder: &mut CircuitBuilder<F, D>,
vars: EvaluationTargets<D>,
) -> Vec<ExtensionTarget<D>> {
let zero = builder.zero_extension();
let two = builder.two_extension();
let mut constraints = Vec::with_capacity(self.num_constraints());
for copy in 0..self.num_copies {
let access_index = vars.local_wires[self.wire_access_index(copy)];
let list_items = (0..self.vec_size)
let mut list_items = (0..self.vec_size())
.map(|i| vars.local_wires[self.wire_list_item(i, copy)])
.collect::<Vec<_>>();
let claimed_element = vars.local_wires[self.wire_claimed_element(copy)];
let bits = (0..self.bits)
.map(|i| vars.local_wires[self.wire_bit(i, copy)])
.collect::<Vec<_>>();
for i in 0..self.vec_size {
let cur_index_ext = F::Extension::from_canonical_usize(i);
let cur_index = builder.constant_extension(cur_index_ext);
let difference = builder.sub_extension(cur_index, access_index);
let equality_dummy = vars.local_wires[self.wire_equality_dummy_for_index(i, copy)];
let index_matches = vars.local_wires[self.wire_index_matches_for_index(i, copy)];
let one = builder.one_extension();
let not_index_matches = builder.sub_extension(one, index_matches);
let first_equality_constraint =
builder.mul_sub_extension(difference, equality_dummy, not_index_matches);
constraints.push(first_equality_constraint);
let second_equality_constraint = builder.mul_extension(index_matches, difference);
constraints.push(second_equality_constraint);
// Output constraint.
let diff = builder.sub_extension(list_items[i], claimed_element);
let conditional_diff = builder.mul_extension(index_matches, diff);
constraints.push(conditional_diff);
// Assert that each bit wire value is indeed boolean.
for &b in &bits {
constraints.push(builder.mul_sub_extension(b, b, b));
}
// Assert that the binary decomposition was correct.
let reconstructed_index = bits
.iter()
.rev()
.fold(zero, |acc, &b| builder.mul_add_extension(acc, two, b));
constraints.push(builder.sub_extension(reconstructed_index, access_index));
// Repeatedly fold the list, selecting the left or right item from each pair based on
// the corresponding bit.
for b in bits {
list_items = list_items
.iter()
.tuples()
.map(|(&x, &y)| builder.select_ext_generalized(b, y, x))
.collect()
}
debug_assert_eq!(list_items.len(), 1);
constraints.push(builder.sub_extension(list_items[0], claimed_element));
}
constraints
@ -207,7 +233,7 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for RandomAccessGate<F, D> {
}
fn num_wires(&self) -> usize {
self.wire_index_matches_for_index(self.vec_size - 1, self.num_copies - 1) + 1
self.wire_bit(self.bits - 1, self.num_copies - 1) + 1
}
fn num_constants(&self) -> usize {
@ -215,11 +241,12 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for RandomAccessGate<F, D> {
}
fn degree(&self) -> usize {
2
self.bits + 1
}
fn num_constraints(&self) -> usize {
3 * self.num_copies * self.vec_size
let constraints_per_copy = self.bits + 2;
self.num_copies * constraints_per_copy
}
}
@ -234,10 +261,8 @@ impl<F: Extendable<D>, const D: usize> SimpleGenerator<F> for RandomAccessGenera
fn dependencies(&self) -> Vec<Target> {
let local_target = |input| Target::wire(self.gate_index, input);
let mut deps = Vec::new();
deps.push(local_target(self.gate.wire_access_index(self.copy)));
deps.push(local_target(self.gate.wire_claimed_element(self.copy)));
for i in 0..self.gate.vec_size {
let mut deps = vec![local_target(self.gate.wire_access_index(self.copy))];
for i in 0..self.gate.vec_size() {
deps.push(local_target(self.gate.wire_list_item(i, self.copy)));
}
deps
@ -250,11 +275,12 @@ impl<F: Extendable<D>, const D: usize> SimpleGenerator<F> for RandomAccessGenera
};
let get_local_wire = |input| witness.get_wire(local_wire(input));
let mut set_local_wire = |input, value| out_buffer.set_wire(local_wire(input), value);
// Compute the new vector and the values for equality_dummy and index_matches
let vec_size = self.gate.vec_size;
let access_index_f = get_local_wire(self.gate.wire_access_index(self.copy));
let copy = self.copy;
let vec_size = self.gate.vec_size();
let access_index_f = get_local_wire(self.gate.wire_access_index(copy));
let access_index = access_index_f.to_canonical_u64() as usize;
debug_assert!(
access_index < vec_size,
@ -263,22 +289,14 @@ impl<F: Extendable<D>, const D: usize> SimpleGenerator<F> for RandomAccessGenera
vec_size
);
for i in 0..vec_size {
let equality_dummy_wire =
local_wire(self.gate.wire_equality_dummy_for_index(i, self.copy));
let index_matches_wire =
local_wire(self.gate.wire_index_matches_for_index(i, self.copy));
set_local_wire(
self.gate.wire_claimed_element(copy),
get_local_wire(self.gate.wire_list_item(access_index, copy)),
);
if i == access_index {
out_buffer.set_wire(equality_dummy_wire, F::ONE);
out_buffer.set_wire(index_matches_wire, F::ONE);
} else {
out_buffer.set_wire(
equality_dummy_wire,
(F::from_canonical_usize(i) - F::from_canonical_usize(access_index)).inverse(),
);
out_buffer.set_wire(index_matches_wire, F::ZERO);
}
for i in 0..self.gate.bits {
let bit = F::from_bool(((access_index >> i) & 1) != 0);
set_local_wire(self.gate.wire_bit(i, copy), bit);
}
}
}
@ -322,6 +340,7 @@ mod tests {
/// Returns the local wires for a random access gate given the vectors, elements to compare,
/// and indices.
fn get_wires(
bits: usize,
lists: Vec<Vec<F>>,
access_indices: Vec<usize>,
claimed_elements: Vec<F>,
@ -330,8 +349,7 @@ mod tests {
let vec_size = lists[0].len();
let mut v = Vec::new();
let mut equality_dummy_vals = Vec::new();
let mut index_matches_vals = Vec::new();
let mut bit_vals = Vec::new();
for copy in 0..num_copies {
let access_index = access_indices[copy];
v.push(F::from_canonical_usize(access_index));
@ -340,26 +358,17 @@ mod tests {
v.push(lists[copy][j]);
}
for i in 0..vec_size {
if i == access_index {
equality_dummy_vals.push(F::ONE);
index_matches_vals.push(F::ONE);
} else {
equality_dummy_vals.push(
(F::from_canonical_usize(i) - F::from_canonical_usize(access_index))
.inverse(),
);
index_matches_vals.push(F::ZERO);
}
for i in 0..bits {
bit_vals.push(F::from_bool(((access_index >> i) & 1) != 0));
}
}
v.extend(equality_dummy_vals);
v.extend(index_matches_vals);
v.extend(bit_vals);
v.iter().map(|&x| x.into()).collect::<Vec<_>>()
v.iter().map(|&x| x.into()).collect()
}
let vec_size = 3;
let bits = 3;
let vec_size = 1 << bits;
let num_copies = 4;
let lists = (0..num_copies)
.map(|_| F::rand_vec(vec_size))
@ -368,7 +377,7 @@ mod tests {
.map(|_| thread_rng().gen_range(0..vec_size))
.collect::<Vec<_>>();
let gate = RandomAccessGate::<F, D> {
vec_size,
bits,
num_copies,
_phantom: PhantomData,
};
@ -380,13 +389,18 @@ mod tests {
.collect();
let good_vars = EvaluationVars {
local_constants: &[],
local_wires: &get_wires(lists.clone(), access_indices.clone(), good_claimed_elements),
local_wires: &get_wires(
bits,
lists.clone(),
access_indices.clone(),
good_claimed_elements,
),
public_inputs_hash: &HashOut::rand(),
};
let bad_claimed_elements = F::rand_vec(4);
let bad_vars = EvaluationVars {
local_constants: &[],
local_wires: &get_wires(lists, access_indices, bad_claimed_elements),
local_wires: &get_wires(bits, lists, access_indices, bad_claimed_elements),
public_inputs_hash: &HashOut::rand(),
};

View File

@ -0,0 +1,222 @@
use std::ops::Range;
use crate::field::extension_field::target::ExtensionTarget;
use crate::field::extension_field::Extendable;
use crate::field::extension_field::FieldExtension;
use crate::field::field_types::RichField;
use crate::gates::gate::Gate;
use crate::iop::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator};
use crate::iop::target::Target;
use crate::iop::witness::{PartitionWitness, Witness};
use crate::plonk::circuit_builder::CircuitBuilder;
use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase};
/// Computes `sum alpha^i c_i` for a vector `c_i` of `num_coeffs` elements of the extension field.
///
/// Wire layout, where every entry occupies `D` consecutive wires:
/// output | alpha | old accumulator | `num_coeffs` coefficients | `num_coeffs - 1` intermediate
/// accumulators. The last accumulator is not stored separately: it is the output.
#[derive(Debug, Clone)]
pub struct ReducingExtensionGate<const D: usize> {
    /// Number of coefficients processed by one gate.
    pub num_coeffs: usize,
}

impl<const D: usize> ReducingExtensionGate<D> {
    /// Creates a gate processing `num_coeffs` coefficients.
    pub fn new(num_coeffs: usize) -> Self {
        Self { num_coeffs }
    }

    /// Returns the largest `num_coeffs` that fits in a row with `num_wires` wires, of which
    /// `num_routed_wires` are routed.
    ///
    /// Returns `0` when the row cannot even hold the fixed wires (output, alpha and old
    /// accumulator), instead of underflowing.
    pub fn max_coeffs_len(num_wires: usize, num_routed_wires: usize) -> usize {
        // `3*D` routed wires are used for the output, alpha and old accumulator, and each
        // coefficient needs `D` more routed wires.
        let by_routed_wires = num_routed_wires.saturating_sub(3 * D) / D;
        // In total, `num_coeffs*D` wires are needed for coeffs and `(num_coeffs-1)*D` wires for
        // accumulators, on top of the `3*D` fixed ones: `2*D + 2*D*num_coeffs` wires.
        let by_total_wires = num_wires.saturating_sub(2 * D) / (D * 2);
        by_routed_wires.min(by_total_wires)
    }

    /// Wires holding the output (which is also the last accumulator).
    pub fn wires_output() -> Range<usize> {
        0..D
    }

    /// Wires holding `alpha`.
    pub fn wires_alpha() -> Range<usize> {
        D..2 * D
    }

    /// Wires holding the accumulator value the reduction starts from.
    pub fn wires_old_acc() -> Range<usize> {
        2 * D..3 * D
    }

    /// Index of the first coefficient wire, right after output, alpha and old accumulator.
    const START_COEFFS: usize = 3 * D;

    /// Wires holding the `i`-th coefficient.
    pub fn wires_coeff(i: usize) -> Range<usize> {
        Self::START_COEFFS + i * D..Self::START_COEFFS + (i + 1) * D
    }

    /// Index of the first intermediate accumulator wire, right after the coefficients.
    fn start_accs(&self) -> usize {
        Self::START_COEFFS + self.num_coeffs * D
    }

    /// Wires holding the accumulator obtained after processing the `i`-th coefficient.
    fn wires_accs(&self, i: usize) -> Range<usize> {
        debug_assert!(i < self.num_coeffs);
        if i == self.num_coeffs - 1 {
            // The last accumulator is the output.
            return Self::wires_output();
        }
        self.start_accs() + D * i..self.start_accs() + D * (i + 1)
    }
}
impl<F: RichField + Extendable<D>, const D: usize> Gate<F, D> for ReducingExtensionGate<D> {
    /// Identifies the gate by its `Debug` representation.
    fn id(&self) -> String {
        format!("{:?}", self)
    }

    /// Evaluates the constraints on extension-algebra wire values.
    ///
    /// For every coefficient, the stored accumulator must equal the previous accumulator times
    /// `alpha` plus that coefficient; the previous accumulator of the first step is the old
    /// accumulator.
    fn eval_unfiltered(&self, vars: EvaluationVars<F, D>) -> Vec<F::Extension> {
        let alpha = vars.get_local_ext_algebra(Self::wires_alpha());
        let mut running = vars.get_local_ext_algebra(Self::wires_old_acc());

        let mut constraints = Vec::with_capacity(<Self as Gate<F, D>>::num_constraints(self));
        for i in 0..self.num_coeffs {
            let coeff = vars.get_local_ext_algebra(Self::wires_coeff(i));
            let next = vars.get_local_ext_algebra(self.wires_accs(i));
            // Each algebra-valued constraint is flattened into its components.
            constraints.extend((running * alpha + coeff - next).to_basefield_array());
            running = next;
        }
        constraints
    }

    /// Same constraints as `eval_unfiltered`, evaluated on base-field wire values.
    fn eval_unfiltered_base(&self, vars: EvaluationVarsBase<F>) -> Vec<F> {
        let alpha = vars.get_local_ext(Self::wires_alpha());
        let mut running = vars.get_local_ext(Self::wires_old_acc());

        let mut constraints = Vec::with_capacity(<Self as Gate<F, D>>::num_constraints(self));
        for i in 0..self.num_coeffs {
            let coeff = vars.get_local_ext(Self::wires_coeff(i));
            let next = vars.get_local_ext(self.wires_accs(i));
            constraints.extend((running * alpha + coeff - next).to_basefield_array());
            running = next;
        }
        constraints
    }

    /// Same constraints as `eval_unfiltered`, built as targets inside `builder`.
    fn eval_unfiltered_recursively(
        &self,
        builder: &mut CircuitBuilder<F, D>,
        vars: EvaluationTargets<D>,
    ) -> Vec<ExtensionTarget<D>> {
        let alpha = vars.get_local_ext_algebra(Self::wires_alpha());
        let mut running = vars.get_local_ext_algebra(Self::wires_old_acc());

        let mut constraints = Vec::with_capacity(<Self as Gate<F, D>>::num_constraints(self));
        for i in 0..self.num_coeffs {
            let coeff = vars.get_local_ext_algebra(Self::wires_coeff(i));
            let next = vars.get_local_ext_algebra(self.wires_accs(i));
            // running * alpha + coeff - next
            let folded = builder.mul_add_ext_algebra(running, alpha, coeff);
            let constraint = builder.sub_ext_algebra(folded, next);
            constraints.extend(constraint.to_ext_target_array());
            running = next;
        }
        constraints
    }

    /// Returns the single generator that fills in the accumulator wires of this gate.
    fn generators(
        &self,
        gate_index: usize,
        _local_constants: &[F],
    ) -> Vec<Box<dyn WitnessGenerator<F>>> {
        let generator = ReducingGenerator {
            gate_index,
            gate: self.clone(),
        };
        vec![Box::new(generator.adapter())]
    }

    /// Output, alpha and old accumulator (`3 * D` wires), `num_coeffs` coefficients and
    /// `num_coeffs - 1` intermediate accumulators of `D` wires each.
    fn num_wires(&self) -> usize {
        2 * D + 2 * D * self.num_coeffs
    }

    fn num_constants(&self) -> usize {
        0
    }

    /// Every constraint multiplies two wire values (accumulator and `alpha`).
    fn degree(&self) -> usize {
        2
    }

    /// One `D`-component constraint per coefficient.
    fn num_constraints(&self) -> usize {
        D * self.num_coeffs
    }
}
/// Witness generator that fills in the accumulator wires of one `ReducingExtensionGate`.
#[derive(Debug)]
struct ReducingGenerator<const D: usize> {
/// Index of the gate whose wires this generator reads and writes.
gate_index: usize,
/// Copy of the gate, used for its `num_coeffs` and accumulator wire ranges.
gate: ReducingExtensionGate<D>,
}
impl<F: Extendable<D>, const D: usize> SimpleGenerator<F> for ReducingGenerator<D> {
    /// The generator needs `alpha`, the old accumulator and every coefficient of its gate.
    fn dependencies(&self) -> Vec<Target> {
        let gate_index = self.gate_index;

        // Collect the wire indices in the same order: alpha, old accumulator, coefficients.
        let mut wire_indices: Vec<usize> = ReducingExtensionGate::<D>::wires_alpha().collect();
        wire_indices.extend(ReducingExtensionGate::<D>::wires_old_acc());
        for i in 0..self.gate.num_coeffs {
            wire_indices.extend(ReducingExtensionGate::<D>::wires_coeff(i));
        }

        wire_indices
            .into_iter()
            .map(|wire| Target::wire(gate_index, wire))
            .collect()
    }

    /// Computes every accumulator as `previous * alpha + coefficient`, starting from the old
    /// accumulator, and writes each one to its wires.
    fn run_once(&self, witness: &PartitionWitness<F>, out_buffer: &mut GeneratedValues<F>) {
        let gate_index = self.gate_index;
        let read_ext = |wires: Range<usize>| -> F::Extension {
            witness.get_extension_target(ExtensionTarget::from_range(gate_index, wires))
        };

        let alpha = read_ext(ReducingExtensionGate::<D>::wires_alpha());
        let old_acc = read_ext(ReducingExtensionGate::<D>::wires_old_acc());
        // All inputs are read before anything is written to `out_buffer`.
        let coeffs: Vec<F::Extension> = (0..self.gate.num_coeffs)
            .map(|i| read_ext(ReducingExtensionGate::<D>::wires_coeff(i)))
            .collect();

        let mut running = old_acc;
        for (i, coeff) in coeffs.into_iter().enumerate() {
            running = running * alpha + coeff;
            let acc_target = ExtensionTarget::from_range(gate_index, self.gate.wires_accs(i));
            out_buffer.set_extension_target(acc_target, running);
        }
    }
}
#[cfg(test)]
mod tests {
    use anyhow::Result;

    use crate::field::goldilocks_field::GoldilocksField;
    use crate::gates::gate_testing::{test_eval_fns, test_low_degree};
    use crate::gates::reducing_extension::ReducingExtensionGate;

    /// Number of coefficients of the gate exercised by both tests.
    const NUM_COEFFS: usize = 22;

    /// Runs the shared low-degree check on the gate.
    #[test]
    fn low_degree() {
        let gate = ReducingExtensionGate::<4>::new(NUM_COEFFS);
        test_low_degree::<GoldilocksField, _, 4>(gate);
    }

    /// Runs the shared check of the gate's evaluation functions.
    #[test]
    fn eval_fns() -> Result<()> {
        let gate = ReducingExtensionGate::<4>::new(NUM_COEFFS);
        test_eval_fns::<GoldilocksField, _, 4>(gate)
    }
}

View File

@ -0,0 +1,423 @@
use std::marker::PhantomData;
use crate::field::extension_field::target::ExtensionTarget;
use crate::field::extension_field::Extendable;
use crate::field::field_types::{Field, RichField};
use crate::gates::gate::Gate;
use crate::iop::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator};
use crate::iop::target::Target;
use crate::iop::wire::Wire;
use crate::iop::witness::{PartitionWitness, Witness};
use crate::plonk::circuit_builder::CircuitBuilder;
use crate::plonk::circuit_data::CircuitConfig;
use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase};
/// A gate to perform a subtraction on 32-bit limbs: given `x`, `y`, and `borrow`, it returns
/// the result `x - y - borrow` and, if this underflows, a new `borrow`. Inputs are not range-checked.
///
/// One gate holds `num_ops` independent subtractions. The output result of each one is
/// range-checked through a decomposition into limbs, and the output borrow is constrained to
/// be a single bit.
#[derive(Copy, Clone, Debug)]
pub struct U32SubtractionGate<F: RichField + Extendable<D>, const D: usize> {
/// Number of subtraction operations performed by one gate.
pub num_ops: usize,
// Ties the gate to the field type `F` without storing a value of it.
_phantom: PhantomData<F>,
}
impl<F: RichField + Extendable<D>, const D: usize> U32SubtractionGate<F, D> {
    /// Creates a gate holding as many operations as `config` allows.
    pub fn new_from_config(config: &CircuitConfig) -> Self {
        let num_ops = Self::num_ops(config);
        Self {
            num_ops,
            _phantom: PhantomData,
        }
    }

    /// Number of operations that fit in one gate for the given `config`.
    ///
    /// Each operation takes 5 routed wires (two inputs, input borrow, output result, output
    /// borrow) plus one non-routed wire per limb of the output result.
    pub(crate) fn num_ops(config: &CircuitConfig) -> usize {
        let fitting_all_wires = config.num_wires / (5 + Self::num_limbs());
        let fitting_routed_wires = config.num_routed_wires / 5;
        fitting_all_wires.min(fitting_routed_wires)
    }

    /// Wire of input `x` of the `i`-th operation; the first of its 5 routed wires.
    pub fn wire_ith_input_x(&self, i: usize) -> usize {
        debug_assert!(i < self.num_ops);
        5 * i
    }

    /// Wire of input `y` of the `i`-th operation.
    pub fn wire_ith_input_y(&self, i: usize) -> usize {
        self.wire_ith_input_x(i) + 1
    }

    /// Wire of the input borrow of the `i`-th operation.
    pub fn wire_ith_input_borrow(&self, i: usize) -> usize {
        self.wire_ith_input_x(i) + 2
    }

    /// Wire of the output result of the `i`-th operation.
    pub fn wire_ith_output_result(&self, i: usize) -> usize {
        self.wire_ith_input_x(i) + 3
    }

    /// Wire of the output borrow of the `i`-th operation.
    pub fn wire_ith_output_borrow(&self, i: usize) -> usize {
        self.wire_ith_input_x(i) + 4
    }

    /// Number of bits held by each limb of the output result.
    pub fn limb_bits() -> usize {
        2
    }

    /// Number of limbs needed to cover the 32 bits of `output_result`.
    pub fn num_limbs() -> usize {
        let bits_per_limb = Self::limb_bits();
        32 / bits_per_limb
    }

    /// Wire of the `j`-th limb of the output result of the `i`-th operation.
    ///
    /// Limb wires come after the 5 wires of every operation, grouped by operation.
    pub fn wire_ith_output_jth_limb(&self, i: usize, j: usize) -> usize {
        debug_assert!(i < self.num_ops);
        debug_assert!(j < Self::num_limbs());
        let limbs_start = 5 * self.num_ops;
        let op_offset = Self::num_limbs() * i;
        limbs_start + op_offset + j
    }
}
impl<F: RichField + Extendable<D>, const D: usize> Gate<F, D> for U32SubtractionGate<F, D> {
fn id(&self) -> String {
format!("{:?}", self)
}
fn eval_unfiltered(&self, vars: EvaluationVars<F, D>) -> Vec<F::Extension> {
let mut constraints = Vec::with_capacity(self.num_constraints());
for i in 0..self.num_ops {
let input_x = vars.local_wires[self.wire_ith_input_x(i)];
let input_y = vars.local_wires[self.wire_ith_input_y(i)];
let input_borrow = vars.local_wires[self.wire_ith_input_borrow(i)];
let result_initial = input_x - input_y - input_borrow;
let base = F::Extension::from_canonical_u64(1 << 32u64);
let output_result = vars.local_wires[self.wire_ith_output_result(i)];
let output_borrow = vars.local_wires[self.wire_ith_output_borrow(i)];
constraints.push(output_result - (result_initial + base * output_borrow));
// Range-check output_result to be at most 32 bits.
let mut combined_limbs = F::Extension::ZERO;
let limb_base = F::Extension::from_canonical_u64(1u64 << Self::limb_bits());
for j in (0..Self::num_limbs()).rev() {
let this_limb = vars.local_wires[self.wire_ith_output_jth_limb(i, j)];
let max_limb = 1 << Self::limb_bits();
let product = (0..max_limb)
.map(|x| this_limb - F::Extension::from_canonical_usize(x))
.product();
constraints.push(product);
combined_limbs = limb_base * combined_limbs + this_limb;
}
constraints.push(combined_limbs - output_result);
// Range-check output_borrow to be one bit.
constraints.push(output_borrow * (F::Extension::ONE - output_borrow));
}
constraints
}
fn eval_unfiltered_base(&self, vars: EvaluationVarsBase<F>) -> Vec<F> {
let mut constraints = Vec::with_capacity(self.num_constraints());
for i in 0..self.num_ops {
let input_x = vars.local_wires[self.wire_ith_input_x(i)];
let input_y = vars.local_wires[self.wire_ith_input_y(i)];
let input_borrow = vars.local_wires[self.wire_ith_input_borrow(i)];
let result_initial = input_x - input_y - input_borrow;
let base = F::from_canonical_u64(1 << 32u64);
let output_result = vars.local_wires[self.wire_ith_output_result(i)];
let output_borrow = vars.local_wires[self.wire_ith_output_borrow(i)];
constraints.push(output_result - (result_initial + base * output_borrow));
// Range-check output_result to be at most 32 bits.
let mut combined_limbs = F::ZERO;
let limb_base = F::from_canonical_u64(1u64 << Self::limb_bits());
for j in (0..Self::num_limbs()).rev() {
let this_limb = vars.local_wires[self.wire_ith_output_jth_limb(i, j)];
let max_limb = 1 << Self::limb_bits();
let product = (0..max_limb)
.map(|x| this_limb - F::from_canonical_usize(x))
.product();
constraints.push(product);
combined_limbs = limb_base * combined_limbs + this_limb;
}
constraints.push(combined_limbs - output_result);
// Range-check output_borrow to be one bit.
constraints.push(output_borrow * (F::ONE - output_borrow));
}
constraints
}
/// Builds, inside the circuit, the same constraints that the gate evaluates natively.
///
/// For every operation this emits, in order: the subtraction relation
/// `output_result - (x - y - borrow_in + 2^32 * output_borrow)`, one range-check product per
/// limb, the check that the limbs recombine to `output_result`, and the booleanity check of
/// `output_borrow`.
fn eval_unfiltered_recursively(
    &self,
    builder: &mut CircuitBuilder<F, D>,
    vars: EvaluationTargets<D>,
) -> Vec<ExtensionTarget<D>> {
    let mut constraints = Vec::with_capacity(self.num_constraints());

    for op in 0..self.num_ops {
        let x = vars.local_wires[self.wire_ith_input_x(op)];
        let y = vars.local_wires[self.wire_ith_input_y(op)];
        let borrow_in = vars.local_wires[self.wire_ith_input_borrow(op)];
        let output_result = vars.local_wires[self.wire_ith_output_result(op)];
        let output_borrow = vars.local_wires[self.wire_ith_output_borrow(op)];

        // output_result must equal x - y - borrow_in + 2^32 * output_borrow.
        let x_minus_y = builder.sub_extension(x, y);
        let raw_difference = builder.sub_extension(x_minus_y, borrow_in);
        let two_pow_32 = builder.constant_extension(F::Extension::from_canonical_u64(1 << 32u64));
        let expected_result = builder.mul_add_extension(two_pow_32, output_borrow, raw_difference);
        let result_mismatch = builder.sub_extension(output_result, expected_result);
        constraints.push(result_mismatch);

        // Range-check output_result to be at most 32 bits, limb by limb.
        let mut recombined = builder.zero_extension();
        let limb_base = builder
            .constant_extension(F::Extension::from_canonical_u64(1u64 << Self::limb_bits()));
        for j in (0..Self::num_limbs()).rev() {
            let limb = vars.local_wires[self.wire_ith_output_jth_limb(op, j)];
            let max_limb = 1 << Self::limb_bits();

            // The product over all admissible limb values vanishes iff the limb is one of them.
            let mut vanishing = builder.one_extension();
            for candidate in 0..max_limb {
                let candidate_target =
                    builder.constant_extension(F::Extension::from_canonical_usize(candidate));
                let offset = builder.sub_extension(limb, candidate_target);
                vanishing = builder.mul_extension(vanishing, offset);
            }
            constraints.push(vanishing);

            // Most significant limb first: recombined = recombined * limb_base + limb.
            recombined = builder.mul_add_extension(limb_base, recombined, limb);
        }
        let limbs_mismatch = builder.sub_extension(recombined, output_result);
        constraints.push(limbs_mismatch);

        // Range-check output_borrow to be one bit: b * (1 - b) == 0.
        let one = builder.one_extension();
        let one_minus_borrow = builder.sub_extension(one, output_borrow);
        let borrow_booleanity = builder.mul_extension(output_borrow, one_minus_borrow);
        constraints.push(borrow_booleanity);
    }

    constraints
}
/// Creates one witness generator per subtraction operation packed into this gate.
///
/// The gate has no constants, so `_local_constants` is ignored.
fn generators(
    &self,
    gate_index: usize,
    _local_constants: &[F],
) -> Vec<Box<dyn WitnessGenerator<F>>> {
    let mut generators: Vec<Box<dyn WitnessGenerator<F>>> = Vec::with_capacity(self.num_ops);
    for i in 0..self.num_ops {
        let generator = U32SubtractionGenerator {
            gate: *self,
            gate_index,
            i,
            _phantom: PhantomData,
        };
        generators.push(Box::new(generator.adapter()));
    }
    generators
}
/// Total number of wires used by the gate: each operation takes five fixed wires (input x,
/// input y, input borrow, output result, output borrow) plus one wire per output limb.
fn num_wires(&self) -> usize {
self.num_ops * (5 + Self::num_limbs())
}
/// This gate does not use any constants.
fn num_constants(&self) -> usize {
0
}
/// Degree of the gate's constraints. The per-limb range check is a product of
/// `1 << limb_bits()` linear factors, which gives the value returned here.
fn degree(&self) -> usize {
1 << Self::limb_bits()
}
/// Number of constraints produced by the gate. Each operation contributes three fixed
/// constraints (subtraction relation, limb recombination, borrow booleanity) plus one
/// range-check constraint per limb.
fn num_constraints(&self) -> usize {
self.num_ops * (3 + Self::num_limbs())
}
}
/// Witness generator for a single subtraction operation of a `U32SubtractionGate`.
///
/// Given the operation's input wires it fills in the output result, the output borrow and the
/// limbs of the output result.
#[derive(Clone, Debug)]
struct U32SubtractionGenerator<F: RichField + Extendable<D>, const D: usize> {
// The gate this generator belongs to; used to look up wire indices.
gate: U32SubtractionGate<F, D>,
// Index of the gate instance in the circuit.
gate_index: usize,
// Index of the operation within the gate that this generator handles.
i: usize,
_phantom: PhantomData<F>,
}
impl<F: RichField + Extendable<D>, const D: usize> SimpleGenerator<F>
    for U32SubtractionGenerator<F, D>
{
    /// The generator can run once the three input wires of its operation (x, y and the incoming
    /// borrow) are known.
    fn dependencies(&self) -> Vec<Target> {
        [
            self.gate.wire_ith_input_x(self.i),
            self.gate.wire_ith_input_y(self.i),
            self.gate.wire_ith_input_borrow(self.i),
        ]
        .iter()
        .map(|&input| Target::wire(self.gate_index, input))
        .collect()
    }

    /// Computes `x - y - borrow_in` in the field, derives the outgoing borrow, and writes the
    /// output result, the output borrow and the limbs of the output result to `out_buffer`.
    fn run_once(&self, witness: &PartitionWitness<F>, out_buffer: &mut GeneratedValues<F>) {
        let wire_at = |input| Wire {
            gate: self.gate_index,
            input,
        };
        let read_wire = |input| witness.get_wire(wire_at(input));

        let x = read_wire(self.gate.wire_ith_input_x(self.i));
        let y = read_wire(self.gate.wire_ith_input_y(self.i));
        let borrow_in = read_wire(self.gate.wire_ith_input_borrow(self.i));

        // A field difference whose canonical value exceeds 2^32 is treated as having wrapped
        // around, i.e. the subtraction needs a borrow.
        let raw_difference = x - y - borrow_in;
        let needs_borrow = raw_difference.to_canonical_u64() > (1 << 32u64);
        let output_borrow = if needs_borrow { F::ONE } else { F::ZERO };

        let two_pow_32 = F::from_canonical_u64(1 << 32u64);
        let output_result = raw_difference + two_pow_32 * output_borrow;

        out_buffer.set_wire(
            wire_at(self.gate.wire_ith_output_result(self.i)),
            output_result,
        );
        out_buffer.set_wire(
            wire_at(self.gate.wire_ith_output_borrow(self.i)),
            output_borrow,
        );

        // Split the result into limbs, least significant first.
        let num_limbs = U32SubtractionGate::<F, D>::num_limbs();
        let limb_base = 1u64 << U32SubtractionGate::<F, D>::limb_bits();
        let mut remaining = output_result.to_canonical_u64();
        for j in 0..num_limbs {
            let limb = F::from_canonical_u64(remaining % limb_base);
            remaining /= limb_base;
            out_buffer.set_wire(
                wire_at(self.gate.wire_ith_output_jth_limb(self.i, j)),
                limb,
            );
        }
    }
}
#[cfg(test)]
mod tests {
    use std::marker::PhantomData;

    use anyhow::Result;
    use rand::Rng;

    use crate::field::extension_field::quartic::QuarticExtension;
    use crate::field::field_types::{Field, PrimeField};
    use crate::field::goldilocks_field::GoldilocksField;
    use crate::gates::gate::Gate;
    use crate::gates::gate_testing::{test_eval_fns, test_low_degree};
    use crate::gates::subtraction_u32::U32SubtractionGate;
    use crate::hash::hash_types::HashOut;
    use crate::plonk::vars::EvaluationVars;

    #[test]
    fn low_degree() {
        let gate = U32SubtractionGate::<GoldilocksField, 4> {
            num_ops: 3,
            _phantom: PhantomData,
        };
        test_low_degree::<GoldilocksField, _, 4>(gate)
    }

    #[test]
    fn eval_fns() -> Result<()> {
        let gate = U32SubtractionGate::<GoldilocksField, 4> {
            num_ops: 3,
            _phantom: PhantomData,
        };
        test_eval_fns::<GoldilocksField, _, 4>(gate)
    }

    /// Fills the gate's wires with a correctly computed witness for random inputs and checks
    /// that every constraint evaluates to zero.
    #[test]
    fn test_gate_constraint() {
        type F = GoldilocksField;
        type FF = QuarticExtension<GoldilocksField>;
        const D: usize = 4;
        const NUM_U32_SUBTRACTION_OPS: usize = 3;

        /// Lays out the wires the same way the test expects the gate to read them: five fixed
        /// wires per operation first, followed by the limbs of every operation.
        fn get_wires(inputs_x: Vec<u64>, inputs_y: Vec<u64>, borrows: Vec<u64>) -> Vec<FF> {
            let num_limbs = U32SubtractionGate::<F, D>::num_limbs();
            let limb_base = 1u64 << U32SubtractionGate::<F, D>::limb_bits();

            let mut fixed_wires = Vec::new();
            let mut limb_wires = Vec::new();
            for op in 0..NUM_U32_SUBTRACTION_OPS {
                let x = F::from_canonical_u64(inputs_x[op]);
                let y = F::from_canonical_u64(inputs_y[op]);
                let borrow_in = F::from_canonical_u64(borrows[op]);

                let raw_difference = x - y - borrow_in;
                let needs_borrow = raw_difference.to_canonical_u64() > (1 << 32u64);
                let output_borrow = if needs_borrow { F::ONE } else { F::ZERO };
                let two_pow_32 = F::from_canonical_u64(1 << 32u64);
                let output_result = raw_difference + two_pow_32 * output_borrow;

                fixed_wires.extend([x, y, borrow_in, output_result, output_borrow]);

                // Limbs of the result, least significant first.
                let mut remaining = output_result.to_canonical_u64();
                for _ in 0..num_limbs {
                    limb_wires.push(F::from_canonical_u64(remaining % limb_base));
                    remaining /= limb_base;
                }
            }

            fixed_wires
                .into_iter()
                .chain(limb_wires)
                .map(|wire| wire.into())
                .collect()
        }

        let mut rng = rand::thread_rng();
        let inputs_x: Vec<u64> = (0..NUM_U32_SUBTRACTION_OPS)
            .map(|_| u64::from(rng.gen::<u32>()))
            .collect();
        let inputs_y: Vec<u64> = (0..NUM_U32_SUBTRACTION_OPS)
            .map(|_| u64::from(rng.gen::<u32>()))
            .collect();
        let borrows: Vec<u64> = (0..NUM_U32_SUBTRACTION_OPS)
            .map(|_| u64::from(rng.gen::<u32>() % 2))
            .collect();

        let gate = U32SubtractionGate::<F, D> {
            num_ops: NUM_U32_SUBTRACTION_OPS,
            _phantom: PhantomData,
        };

        let wires = get_wires(inputs_x, inputs_y, borrows);
        let public_inputs_hash = HashOut::rand();
        let vars = EvaluationVars {
            local_constants: &[],
            local_wires: &wires,
            public_inputs_hash: &public_inputs_hash,
        };

        assert!(
            gate.eval_unfiltered(vars).iter().all(|x| x.is_zero()),
            "Gate constraints are not satisfied."
        );
    }
}

View File

@ -1,5 +1,7 @@
#![allow(clippy::assertions_on_constants)]
use std::arch::aarch64::*;
use std::convert::TryInto;
use std::arch::asm;
use static_assertions::const_assert;
use unroll::unroll_for_loops;
@ -172,9 +174,7 @@ unsafe fn multiply(x: u64, y: u64) -> u64 {
let xy_hi_lo_mul_epsilon = mul_epsilon(xy_hi);
// add_with_wraparound is safe, as xy_hi_lo_mul_epsilon <= 0xfffffffe00000001 <= ORDER.
let res1 = add_with_wraparound(res0, xy_hi_lo_mul_epsilon);
res1
add_with_wraparound(res0, xy_hi_lo_mul_epsilon)
}
// ==================================== STANDALONE CONST LAYER =====================================
@ -267,9 +267,7 @@ unsafe fn mds_reduce(
// Multiply by EPSILON and accumulate.
let res_unadj = vmlal_laneq_u32::<0>(res_lo, res_hi_hi, mds_consts0);
let res_adj = vcgtq_u64(res_lo, res_unadj);
let res = vsraq_n_u64::<32>(res_unadj, res_adj);
res
vsraq_n_u64::<32>(res_unadj, res_adj)
}
#[inline(always)]
@ -969,8 +967,7 @@ unsafe fn partial_round(
#[inline(always)]
unsafe fn full_round(state: [u64; 12], round_constants: &[u64; WIDTH]) -> [u64; 12] {
let state = sbox_layer_full(state);
let state = mds_const_layers_full(state, round_constants);
state
mds_const_layers_full(state, round_constants)
}
#[inline]

View File

@ -1,5 +1,4 @@
use core::arch::x86_64::*;
use std::convert::TryInto;
use std::mem::size_of;
use static_assertions::const_assert;

View File

@ -1,5 +1,3 @@
use std::convert::TryInto;
use rand::Rng;
use serde::{Deserialize, Deserializer, Serialize, Serializer};

View File

@ -1,7 +1,5 @@
//! Concrete instantiation of a hash function.
use std::convert::TryInto;
use crate::field::extension_field::Extendable;
use crate::field::field_types::RichField;
use crate::hash::hash_types::{HashOut, HashOutTarget};
@ -131,10 +129,8 @@ pub fn hash_n_to_m<F: RichField, P: PlonkyPermutation<F>>(
// Absorb all input chunks.
for input_chunk in inputs.chunks(SPONGE_RATE) {
for i in 0..input_chunk.len() {
state[i] = input_chunk[i];
}
state = P::permute(state);
state[..input_chunk.len()].copy_from_slice(input_chunk);
state = permute(state);
}
// Squeeze until we have the desired number of outputs.

View File

@ -1,5 +1,3 @@
use std::convert::TryInto;
use anyhow::{ensure, Result};
use serde::{Deserialize, Serialize};
@ -55,6 +53,7 @@ pub(crate) fn verify_merkle_proof<F: RichField, H: Hasher<F>>(
impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
/// Verifies that the given leaf data is present at the given index in the Merkle tree with the
/// given cap. The index is given by its little-endian bits.
#[cfg(test)]
pub(crate) fn verify_merkle_proof<H: AlgebraicHasher<F>>(
&mut self,
leaf_data: Vec<Target>,
@ -94,7 +93,7 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
proof: &MerkleProofTarget,
) {
let zero = self.zero();
let mut state: HashOutTarget = self.hash_or_noop::<H>(leaf_data);
let mut state:HashOutTarget = self.hash_or_noop(leaf_data);
for (&bit, &sibling) in leaf_index_bits.iter().zip(&proof.siblings) {
let mut perm_inputs = [zero; SPONGE_WIDTH];
@ -116,7 +115,7 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
}
}
pub fn assert_hashes_equal(&mut self, x: HashOutTarget, y: HashOutTarget) {
pub fn connect_hashes(&mut self, x: HashOutTarget, y: HashOutTarget) {
for i in 0..4 {
self.connect(x.elements[i], y.elements[i]);
}

View File

@ -1,8 +1,6 @@
//! Implementation of the Poseidon hash function, as described in
//! https://eprint.iacr.org/2019/458.pdf
use std::convert::TryInto;
use unroll::unroll_for_loops;
use crate::field::extension_field::target::ExtensionTarget;
@ -452,9 +450,10 @@ pub trait Poseidon: PrimeField {
s0,
);
for i in 1..WIDTH {
let t = <Self as Poseidon>::FAST_PARTIAL_ROUND_W_HATS[r][i - 1];
let t = Self::from_canonical_u64(t);
d = builder.mul_const_add_extension(t, state[i], d);
let t = <Self as Poseidon<WIDTH>>::FAST_PARTIAL_ROUND_W_HATS[r][i - 1];
let t = Self::Extension::from_canonical_u64(t);
let t = builder.constant_extension(t);
d = builder.mul_add_extension(t, state[i], d);
}
let mut result = [builder.zero_extension(); WIDTH];
@ -624,6 +623,7 @@ pub trait Poseidon: PrimeField {
}
}
#[cfg(test)]
pub(crate) mod test_helpers {
use crate::field::field_types::Field;
use crate::hash::hashing::SPONGE_WIDTH;

View File

@ -1,7 +1,7 @@
use crate::field::extension_field::target::ExtensionTarget;
use std::convert::TryInto;
use std::marker::PhantomData;
use crate::field::extension_field::target::ExtensionTarget;
use crate::field::extension_field::{Extendable, FieldExtension};
use crate::field::field_types::RichField;
use crate::hash::hash_types::{HashOut, HashOutTarget, MerkleCapTarget};

View File

@ -1,9 +1,15 @@
use std::fmt::Debug;
use std::marker::PhantomData;
use num::BigUint;
use crate::field::extension_field::target::ExtensionTarget;
use crate::field::extension_field::{Extendable, FieldExtension};
use crate::field::field_types::Field;
use crate::field::field_types::{Field, RichField};
use crate::gadgets::arithmetic_u32::U32Target;
use crate::gadgets::biguint::BigUintTarget;
use crate::gadgets::nonnative::NonNativeTarget;
use crate::hash::hash_types::{HashOut, HashOutTarget};
use crate::iop::target::Target;
use crate::iop::wire::Wire;
@ -89,7 +95,7 @@ pub(crate) fn generate_partial_witness<
assert_eq!(
remaining_generators, 0,
"{} generators weren't run",
remaining_generators
remaining_generators,
);
witness
@ -156,6 +162,24 @@ impl<F: Field> GeneratedValues<F> {
self.target_values.push((target, value))
}
/// Sets the target wrapped by `target` to the field element corresponding to `value`.
fn set_u32_target(&mut self, target: U32Target, value: u32) {
self.set_target(target.0, F::from_canonical_u32(value))
}
/// Sets every limb of `target` from the little-endian `u32` digits of `value`, padding the
/// high limbs with zeros.
///
/// # Panics
///
/// Panics if `value` has more `u32` digits than `target` has limbs.
pub fn set_biguint_target(&mut self, target: BigUintTarget, value: BigUint) {
    let mut limbs = value.to_u32_digits();
    assert!(target.num_limbs() >= limbs.len());
    // After padding there is exactly one digit per limb of the target.
    limbs.resize(target.num_limbs(), 0);
    for (i, digit) in limbs.into_iter().enumerate() {
        self.set_u32_target(target.get_limb(i), digit);
    }
}
/// Sets a non-native field target by converting `value` to a `BigUint` and writing it to the
/// target's underlying `BigUintTarget`.
pub fn set_nonnative_target<FF: Field>(&mut self, target: NonNativeTarget<FF>, value: FF) {
self.set_biguint_target(target.value, value.to_biguint())
}
pub fn set_hash_target(&mut self, ht: HashOutTarget, value: HashOut<F>) {
ht.elements
.iter()

View File

@ -41,6 +41,7 @@ impl Target {
/// A `Target` which has already been constrained such that it can only be 0 or 1.
#[derive(Copy, Clone, Debug)]
#[allow(clippy::manual_non_exhaustive)]
pub struct BoolTarget {
pub target: Target,
/// This private field is here to force all instantiations to go through `new_unsafe`.

View File

@ -1,9 +1,14 @@
use std::collections::HashMap;
use std::convert::TryInto;
use num::{BigUint, FromPrimitive, Zero};
use crate::field::extension_field::target::ExtensionTarget;
use crate::field::extension_field::{Extendable, FieldExtension};
use crate::field::field_types::{Field, RichField};
use crate::field::field_types::Field;
use crate::gadgets::arithmetic_u32::U32Target;
use crate::gadgets::biguint::BigUintTarget;
use crate::gadgets::nonnative::NonNativeTarget;
use crate::hash::hash_types::HashOutTarget;
use crate::hash::hash_types::{HashOut, MerkleCapTarget};
use crate::hash::merkle_tree::MerkleCap;
@ -54,6 +59,24 @@ pub trait Witness<F: Field> {
panic!("not a bool")
}
/// Reads the limbs of `target` from the witness and recombines them into a `BigUint`,
/// treating limb `0` as the least significant base-2^32 digit.
fn get_biguint_target(&self, target: BigUintTarget) -> BigUint {
    let limb_base = BigUint::from_u64(1 << 32u64).unwrap();
    // Horner evaluation, starting from the most significant limb.
    (0..target.num_limbs())
        .rev()
        .fold(BigUint::zero(), |mut acc, i| {
            let limb = target.get_limb(i);
            acc *= &limb_base;
            acc += self.get_target(limb.0).to_biguint();
            acc
        })
}
/// Reads the `BigUint` stored in the target's underlying `BigUintTarget` and converts it to an
/// element of the non-native field `FF`.
fn get_nonnative_target<FF: Field>(&self, target: NonNativeTarget<FF>) -> FF {
    FF::from_biguint(self.get_biguint_target(target.value))
}
fn get_hash_target(&self, ht: HashOutTarget) -> HashOut<F> {
HashOut {
elements: self.get_targets(&ht.elements).try_into().unwrap(),
@ -122,6 +145,16 @@ pub trait Witness<F: Field> {
self.set_target(target.target, F::from_bool(value))
}
/// Sets the target wrapped by `target` to the field element corresponding to `value`.
fn set_u32_target(&mut self, target: U32Target, value: u32) {
self.set_target(target.0, F::from_canonical_u32(value))
}
/// Sets every limb of `target` from the little-endian `u32` digits of `value`.
///
/// `value` may have fewer digits than `target` has limbs; the remaining high limbs are set to
/// zero so that no limb of the target is left unassigned.
///
/// # Panics
///
/// Panics if `value` needs more `u32` digits than `target` has limbs, since the value could
/// then only be stored truncated.
fn set_biguint_target(&mut self, target: &BigUintTarget, value: &BigUint) {
    let num_limbs = target.limbs.len();
    let mut digits = value.to_u32_digits();
    assert!(
        num_limbs >= digits.len(),
        "value has {} u32 digits but the target only has {} limbs",
        digits.len(),
        num_limbs
    );
    // Zero-pad so that zipping below covers every limb of the target.
    digits.resize(num_limbs, 0);
    for (&limb_target, digit) in target.limbs.iter().zip(digits) {
        self.set_u32_target(limb_target, digit);
    }
}
fn set_wire(&mut self, wire: Wire, value: F) {
self.set_target(Target::Wire(wire), value)
}

View File

@ -1,9 +1,15 @@
#![feature(asm)]
#![feature(destructuring_assignment)]
#![allow(incomplete_features)]
#![allow(const_evaluatable_unchecked)]
#![allow(clippy::new_without_default)]
#![allow(clippy::too_many_arguments)]
#![allow(clippy::len_without_is_empty)]
#![allow(clippy::needless_range_loop)]
#![feature(asm_sym)]
#![feature(generic_const_exprs)]
#![feature(specialization)]
#![feature(stdsimd)]
pub mod curve;
pub mod field;
pub mod fri;
pub mod gadgets;
@ -13,3 +19,11 @@ pub mod iop;
pub mod plonk;
pub mod polynomial;
pub mod util;
// Set up Jemalloc
#[cfg(not(target_env = "msvc"))]
use jemallocator::Jemalloc;
#[cfg(not(target_env = "msvc"))]
#[global_allocator]
static GLOBAL: Jemalloc = Jemalloc;

View File

@ -1,6 +1,5 @@
use std::cmp::max;
use std::collections::{BTreeMap, HashMap, HashSet};
use std::convert::TryInto;
use std::time::Instant;
use log::{debug, info, Level};
@ -12,14 +11,20 @@ use crate::field::fft::fft_root_table;
use crate::field::field_types::Field;
use crate::fri::commitment::PolynomialBatchCommitment;
use crate::fri::{FriConfig, FriParams};
use crate::gadgets::arithmetic_extension::ArithmeticOperation;
use crate::gates::arithmetic::ArithmeticExtensionGate;
use crate::gadgets::arithmetic::BaseArithmeticOperation;
use crate::gadgets::arithmetic_extension::ExtensionArithmeticOperation;
use crate::gadgets::arithmetic_u32::U32Target;
use crate::gates::arithmetic_base::ArithmeticGate;
use crate::gates::arithmetic_extension::ArithmeticExtensionGate;
use crate::gates::arithmetic_u32::U32ArithmeticGate;
use crate::gates::constant::ConstantGate;
use crate::gates::gate::{Gate, GateInstance, GateRef, PrefixedGate};
use crate::gates::gate_tree::Tree;
use crate::gates::multiplication_extension::MulExtensionGate;
use crate::gates::noop::NoopGate;
use crate::gates::public_input::PublicInputGate;
use crate::gates::random_access::RandomAccessGate;
use crate::gates::subtraction_u32::U32SubtractionGate;
use crate::gates::switch::SwitchGate;
use crate::hash::hash_types::{HashOutTarget, MerkleCapTarget};
use crate::iop::generator::{
@ -35,7 +40,7 @@ use crate::plonk::config::{GenericConfig, Hasher};
use crate::plonk::copy_constraint::CopyConstraint;
use crate::plonk::permutation_argument::Forest;
use crate::plonk::plonk_common::PlonkPolynomials;
use crate::polynomial::polynomial::PolynomialValues;
use crate::polynomial::PolynomialValues;
use crate::util::context_tree::ContextTree;
use crate::util::marking::{Markable, MarkedTargets};
use crate::util::partial_products::num_partial_products;
@ -71,24 +76,13 @@ pub struct CircuitBuilder<F: Extendable<D>, const D: usize> {
constants_to_targets: HashMap<F, Target>,
targets_to_constants: HashMap<Target, F>,
/// Memoized results of `arithmetic` calls.
pub(crate) base_arithmetic_results: HashMap<BaseArithmeticOperation<F>, Target>,
/// Memoized results of `arithmetic_extension` calls.
pub(crate) arithmetic_results: HashMap<ArithmeticOperation<F, D>, ExtensionTarget<D>>,
pub(crate) arithmetic_results: HashMap<ExtensionArithmeticOperation<F, D>, ExtensionTarget<D>>,
/// A map `(c0, c1) -> (g, i)` from constants `(c0,c1)` to an available arithmetic gate using
/// these constants with gate index `g` and already using `i` arithmetic operations.
pub(crate) free_arithmetic: HashMap<(F, F), (usize, usize)>,
/// A map `(c0, c1) -> (g, i)` from constants `vec_size` to an available arithmetic gate using
/// these constants with gate index `g` and already using `i` random accesses.
pub(crate) free_random_access: HashMap<usize, (usize, usize)>,
// `current_switch_gates[chunk_size - 1]` contains None if we have no switch gates with the value
// chunk_size, and contains `(g, i, c)`, if the gate `g`, at index `i`, already contains `c` copies
// of switches
pub(crate) current_switch_gates: Vec<Option<(SwitchGate<F, D>, usize, usize)>>,
/// An available `ConstantGate` instance, if any.
free_constant: Option<(usize, usize)>,
batched_gates: BatchedGates<F, D>,
}
impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
@ -104,12 +98,10 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
marked_targets: Vec::new(),
generators: Vec::new(),
constants_to_targets: HashMap::new(),
base_arithmetic_results: HashMap::new(),
arithmetic_results: HashMap::new(),
targets_to_constants: HashMap::new(),
free_arithmetic: HashMap::new(),
free_random_access: HashMap::new(),
current_switch_gates: Vec::new(),
free_constant: None,
batched_gates: BatchedGates::new(),
};
builder.check_config();
builder
@ -216,6 +208,7 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
gate_ref,
constants,
});
index
}
@ -260,6 +253,11 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
self.connect(x, zero);
}
/// Constrains `x` to equal one by connecting it to the builder's `one` target.
pub fn assert_one(&mut self, x: Target) {
let one = self.one();
self.connect(x, one);
}
pub fn add_generators(&mut self, generators: Vec<Box<dyn WitnessGenerator<F>>>) {
self.generators.extend(generators);
}
@ -313,26 +311,6 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
target
}
/// Hands out the next free slot of a `ConstantGate` as `(gate index, copy index)`, adding a new
/// `ConstantGate` first when no partially filled one is being tracked.
fn constant_gate_instance(&mut self) -> (usize, usize) {
    let (gate, instance) = match self.free_constant {
        Some(slot) => slot,
        None => {
            let num_consts = self.config.constant_gate_size;
            // The new gate starts out with all-zero constants; `constant` overwrites them as
            // the slots are handed out.
            let gate = self.add_gate(ConstantGate { num_consts }, vec![F::ZERO; num_consts]);
            (gate, 0)
        }
    };

    // Keep tracking the gate only while it still has an unused slot after this one.
    let next_instance = instance + 1;
    self.free_constant = if next_instance < self.config.constant_gate_size {
        Some((gate, next_instance))
    } else {
        None
    };

    (gate, instance)
}
pub fn constants(&mut self, constants: &[F]) -> Vec<Target> {
constants.iter().map(|&c| self.constant(c)).collect()
}
@ -345,6 +323,11 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
}
}
/// Returns a `U32Target` wrapping the constant target for `c`. Because `c` is a `u32`, the
/// value always fits in 32 bits.
pub fn constant_u32(&mut self, c: u32) -> U32Target {
U32Target(self.constant(F::from_canonical_u32(c)))
}
/// If the given target is a constant (i.e. it was created by the `constant(F)` method), returns
/// its constant value. Otherwise, returns `None`.
pub fn target_as_constant(&self, target: Target) -> Option<F> {
@ -396,6 +379,20 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
}
}
/// The number of (base field) `arithmetic` operations that can be performed in a single gate.
///
/// When the configuration does not enable the base arithmetic gate, this is the number of
/// operations per extension arithmetic gate instead.
pub(crate) fn num_base_arithmetic_ops_per_gate(&self) -> usize {
    if !self.config.use_base_arithmetic_gate {
        return self.num_ext_arithmetic_ops_per_gate();
    }
    ArithmeticGate::new_from_config(&self.config).num_ops
}
/// The number of `arithmetic_extension` operations that can be performed in a single gate,
/// taken from an `ArithmeticExtensionGate` built from the current configuration.
pub(crate) fn num_ext_arithmetic_ops_per_gate(&self) -> usize {
ArithmeticExtensionGate::<D>::new_from_config(&self.config).num_ops
}
/// The number of polynomial values that will be revealed per opening, both for the "regular"
/// polynomials and for the Z polynomials. Because calculating these values involves a recursive
/// dependence (the amount of blinding depends on the degree, which depends on the blinding),
@ -566,76 +563,6 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
)
}
/// Connects the wires of every unused operation in the partially filled arithmetic gates to
/// zero, so that all `ArithmeticExtensionGenerator` are run.
fn fill_arithmetic_gates(&mut self) {
    let zero = self.zero_extension();
    // Snapshot the tracked gates: connecting wires below needs `self` mutably.
    let partially_filled: Vec<_> = self.free_arithmetic.values().copied().collect();

    for (gate, first_unused) in partially_filled {
        let ops_per_gate = ArithmeticExtensionGate::<D>::num_ops(&self.config);
        for op in first_unused..ops_per_gate {
            let multiplicand_0 = ExtensionTarget::from_range(
                gate,
                ArithmeticExtensionGate::<D>::wires_ith_multiplicand_0(op),
            );
            let multiplicand_1 = ExtensionTarget::from_range(
                gate,
                ArithmeticExtensionGate::<D>::wires_ith_multiplicand_1(op),
            );
            let addend = ExtensionTarget::from_range(
                gate,
                ArithmeticExtensionGate::<D>::wires_ith_addend(op),
            );

            self.connect_extension(zero, multiplicand_0);
            self.connect_extension(zero, multiplicand_1);
            self.connect_extension(zero, addend);
        }
    }
}
/// Fill the remaining unused random access operations with zeros, so that all
/// `RandomAccessGenerator`s are run.
fn fill_random_access_gates(&mut self) {
let zero = self.zero();
// Iterate over a snapshot of the map, keyed by vector size; the value holds the number of
// copies already used in the tracked gate. The loop body calls `self.random_access`, which
// presumably updates `free_random_access` itself.
for (vec_size, (_, i)) in self.free_random_access.clone() {
let max_copies = RandomAccessGate::<F, D>::max_num_copies(
self.config.num_routed_wires,
self.config.num_wires,
vec_size,
);
// Issue one all-zero random access for every copy that is still unused.
for _ in i..max_copies {
self.random_access(zero, zero, vec![zero; vec_size]);
}
}
}
/// Fill the remaining unused switch gates with dummy values, so that all
/// `SwitchGenerator` are run.
fn fill_switch_gates(&mut self) {
let zero = self.zero();
// `current_switch_gates[chunk_size - 1]` holds the partially filled gate for `chunk_size`,
// if there is one.
for chunk_size in 1..=self.current_switch_gates.len() {
if let Some((gate, gate_index, mut copy)) =
self.current_switch_gates[chunk_size - 1].clone()
{
// Walk over every copy of the gate that has not been used yet.
while copy < gate.num_copies {
for element in 0..chunk_size {
let wire_first_input =
Target::wire(gate_index, gate.wire_first_input(copy, element));
let wire_second_input =
Target::wire(gate_index, gate.wire_second_input(copy, element));
let wire_switch_bool =
Target::wire(gate_index, gate.wire_switch_bool(copy));
// Connect both inputs of this element and the copy's switch bit to zero.
self.connect(zero, wire_first_input);
self.connect(zero, wire_second_input);
self.connect(zero, wire_switch_bool);
}
copy += 1;
}
}
}
}
pub fn print_gate_counts(&self, min_delta: usize) {
// Print gate counts for each context.
self.context_log
@ -659,9 +586,7 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
let mut timing = TimingTree::new("preprocess", Level::Trace);
let start = Instant::now();
self.fill_arithmetic_gates();
self.fill_random_access_gates();
self.fill_switch_gates();
self.fill_batched_gates();
// Hash the public inputs, and route them to a `PublicInputGate` which will enforce that
// those hash wires match the claimed public inputs.
@ -698,7 +623,7 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
..=1 << self.config.rate_bits)
.min_by_key(|&q| num_partial_products(self.config.num_routed_wires, q).0 + q)
.unwrap();
info!("Quotient degree factor set to: {}.", quotient_degree_factor);
debug!("Quotient degree factor set to: {}.", quotient_degree_factor);
let prefixed_gates = PrefixedGate::from_tree(gate_tree);
let subgroup = F::two_adic_subgroup(degree_bits);
@ -710,7 +635,7 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
// Precompute FFT roots.
let max_fft_points =
1 << degree_bits + max(self.config.rate_bits, log2_ceil(quotient_degree_factor));
1 << (degree_bits + max(self.config.rate_bits, log2_ceil(quotient_degree_factor)));
let fft_root_table = fft_root_table(max_fft_points);
let constants_sigmas_vecs = [constant_vecs, sigma_vecs.clone()].concat();
@ -745,7 +670,7 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
let watch_rep_index = forest.parents[watch_index];
generator_indices_by_watches
.entry(watch_rep_index)
.or_insert(vec![])
.or_insert_with(Vec::new)
.push(i);
}
}
@ -801,7 +726,7 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
circuit_digest,
};
info!("Building circuit took {}s", start.elapsed().as_secs_f32());
debug!("Building circuit took {}s", start.elapsed().as_secs_f32());
CircuitData {
prover_only,
verifier_only,
@ -837,3 +762,386 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
}
}
}
/// Various gate types can contain multiple copies in a single Gate. This helper struct lets a
/// CircuitBuilder track such gates that are currently being "filled up."
pub struct BatchedGates<F: RichField + Extendable<D>, const D: usize> {
/// A map `(c0, c1) -> (g, i)` from constants `(c0, c1)` to an available extension arithmetic
/// gate using these constants, with gate index `g` and already using `i` arithmetic operations.
pub(crate) free_arithmetic: HashMap<(F, F), (usize, usize)>,
/// A map `(c0, c1) -> (g, i)` from constants `(c0, c1)` to an available base-field arithmetic
/// gate using these constants, with gate index `g` and already using `i` arithmetic operations.
pub(crate) free_base_arithmetic: HashMap<(F, F), (usize, usize)>,
/// A map `c0 -> (g, i)` from a constant `c0` to an available multiplication gate using this
/// constant, with gate index `g` and already using `i` operations.
pub(crate) free_mul: HashMap<F, (usize, usize)>,
/// A map `b -> (g, i)` from `b` bits to an available random access gate of that size with gate
/// index `g` and already using `i` random accesses.
pub(crate) free_random_access: HashMap<usize, (usize, usize)>,
/// `current_switch_gates[chunk_size - 1]` contains `None` if we have no switch gate with the
/// given `chunk_size` being filled, and contains `(g, i, c)` if the gate `g`, at gate index `i`,
/// already contains `c` copies of switches.
pub(crate) current_switch_gates: Vec<Option<(SwitchGate<F, D>, usize, usize)>>,
/// The `U32ArithmeticGate` currently being filled, as `(gate index, next free copy)`. New u32
/// arithmetic operations are added to this gate before a new one is created.
pub(crate) current_u32_arithmetic_gate: Option<(usize, usize)>,
/// The `U32SubtractionGate` currently being filled, as `(gate index, next free copy)`. New u32
/// subtraction operations are added to this gate before a new one is created.
pub(crate) current_u32_subtraction_gate: Option<(usize, usize)>,
/// An available `ConstantGate` instance, if any.
pub(crate) free_constant: Option<(usize, usize)>,
}
impl<F: RichField + Extendable<D>, const D: usize> BatchedGates<F, D> {
/// Creates an empty tracker: no gate of any kind is currently being filled.
pub fn new() -> Self {
Self {
free_arithmetic: HashMap::new(),
free_base_arithmetic: HashMap::new(),
free_mul: HashMap::new(),
free_random_access: HashMap::new(),
current_switch_gates: Vec::new(),
current_u32_arithmetic_gate: None,
current_u32_subtraction_gate: None,
free_constant: None,
}
}
}
impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
/// Finds the base-field arithmetic gate currently being filled for the given constants, or adds
/// a new one if there is none.
/// Returns `(g, i)` such that there is an `ArithmeticGate` with the given constants at index
/// `g` and the gate's `i`-th operation is available. The slot is marked as used.
pub(crate) fn find_base_arithmetic_gate(&mut self, const_0: F, const_1: F) -> (usize, usize) {
    let constants = (const_0, const_1);
    let tracked = self
        .batched_gates
        .free_base_arithmetic
        .get(&constants)
        .copied();
    let (gate, i) = match tracked {
        Some(slot) => slot,
        None => {
            let new_gate = self.add_gate(
                ArithmeticGate::new_from_config(&self.config),
                vec![const_0, const_1],
            );
            (new_gate, 0)
        }
    };

    // Update `free_base_arithmetic`: point at the next operation, or drop the gate once full.
    let last_op = ArithmeticGate::num_ops(&self.config) - 1;
    if i < last_op {
        self.batched_gates
            .free_base_arithmetic
            .insert(constants, (gate, i + 1));
    } else {
        self.batched_gates.free_base_arithmetic.remove(&constants);
    }

    (gate, i)
}
/// Finds the extension arithmetic gate currently being filled for the given constants, or adds
/// a new one if there is none.
/// Returns `(g, i)` such that there is an `ArithmeticExtensionGate` with the given constants at
/// index `g` and the gate's `i`-th operation is available. The slot is marked as used.
pub(crate) fn find_arithmetic_gate(&mut self, const_0: F, const_1: F) -> (usize, usize) {
    let constants = (const_0, const_1);
    let tracked = self.batched_gates.free_arithmetic.get(&constants).copied();
    let (gate, i) = match tracked {
        Some(slot) => slot,
        None => {
            let new_gate = self.add_gate(
                ArithmeticExtensionGate::new_from_config(&self.config),
                vec![const_0, const_1],
            );
            (new_gate, 0)
        }
    };

    // Update `free_arithmetic`: point at the next operation, or drop the gate once full.
    let last_op = ArithmeticExtensionGate::<D>::num_ops(&self.config) - 1;
    if i < last_op {
        self.batched_gates
            .free_arithmetic
            .insert(constants, (gate, i + 1));
    } else {
        self.batched_gates.free_arithmetic.remove(&constants);
    }

    (gate, i)
}
/// Finds the multiplication gate currently being filled for the given constant, or adds a new
/// one if there is none.
/// Returns `(g, i)` such that there is a `MulExtensionGate` with the given constant at index
/// `g` and the gate's `i`-th operation is available. The slot is marked as used.
pub(crate) fn find_mul_gate(&mut self, const_0: F) -> (usize, usize) {
    let tracked = self.batched_gates.free_mul.get(&const_0).copied();
    let (gate, i) = match tracked {
        Some(slot) => slot,
        None => {
            let new_gate = self.add_gate(
                MulExtensionGate::new_from_config(&self.config),
                vec![const_0],
            );
            (new_gate, 0)
        }
    };

    // Update `free_mul`: point at the next operation, or drop the gate once full.
    let last_op = MulExtensionGate::<D>::num_ops(&self.config) - 1;
    if i < last_op {
        self.batched_gates.free_mul.insert(const_0, (gate, i + 1));
    } else {
        self.batched_gates.free_mul.remove(&const_0);
    }

    (gate, i)
}
/// Finds the random access gate currently being filled for the given number of `bits`, or adds
/// a new one if there is none.
/// Returns `(g, i)` such that there is a random access gate for `bits` at index `g` and the
/// gate's `i`-th random access is available. The slot is marked as used.
pub(crate) fn find_random_access_gate(&mut self, bits: usize) -> (usize, usize) {
    let tracked = self.batched_gates.free_random_access.get(&bits).copied();
    let (gate, i) = match tracked {
        Some(slot) => slot,
        None => {
            let new_gate = self.add_gate(
                RandomAccessGate::new_from_config(&self.config, bits),
                vec![],
            );
            (new_gate, 0)
        }
    };

    // Update `free_random_access`: point at the next copy, or drop the gate once full.
    let num_copies = RandomAccessGate::<F, D>::new_from_config(&self.config, bits).num_copies;
    if i + 1 < num_copies {
        self.batched_gates
            .free_random_access
            .insert(bits, (gate, i + 1));
    } else {
        self.batched_gates.free_random_access.remove(&bits);
    }

    (gate, i)
}
/// Finds the switch gate currently being filled for the given `chunk_size`, or adds a new one
/// if there is none.
/// Returns `(gate, g, c)`: the gate itself, its gate index `g`, and the copy `c` within the gate
/// that is available. The copy is marked as used.
pub(crate) fn find_switch_gate(
    &mut self,
    chunk_size: usize,
) -> (SwitchGate<F, D>, usize, usize) {
    // Make sure there is a slot for this chunk size; slot `chunk_size - 1` tracks it.
    if self.batched_gates.current_switch_gates.len() < chunk_size {
        self.batched_gates
            .current_switch_gates
            .resize(chunk_size, None);
    }
    let slot = chunk_size - 1;

    let (gate, gate_index, next_copy) = self.batched_gates.current_switch_gates[slot]
        .clone()
        .unwrap_or_else(|| {
            let gate = SwitchGate::<F, D>::new_from_config(&self.config, chunk_size);
            let gate_index = self.add_gate(gate.clone(), vec![]);
            (gate, gate_index, 0)
        });

    // Stop tracking the gate once its last copy has been handed out.
    let updated = if next_copy == gate.num_copies - 1 {
        None
    } else {
        Some((gate.clone(), gate_index, next_copy + 1))
    };
    self.batched_gates.current_switch_gates[slot] = updated;

    (gate, gate_index, next_copy)
}
/// Reserves one operation slot in the current `U32ArithmeticGate`, adding a new gate when
/// `batched_gates.current_u32_arithmetic_gate` is `None`.
///
/// Returns `(gate_index, copy)`, where `copy` is the operation of that gate the caller may use.
pub(crate) fn find_u32_arithmetic_gate(&mut self) -> (usize, usize) {
    let tracked = self.batched_gates.current_u32_arithmetic_gate;
    let (gate_index, copy) = tracked.unwrap_or_else(|| {
        let gate = U32ArithmeticGate::new_from_config(&self.config);
        let gate_index = self.add_gate(gate, vec![]);
        (gate_index, 0)
    });

    // Once the last operation has been handed out the gate is no longer tracked.
    let last_op = U32ArithmeticGate::<F, D>::num_ops(&self.config) - 1;
    self.batched_gates.current_u32_arithmetic_gate = if copy == last_op {
        None
    } else {
        Some((gate_index, copy + 1))
    };

    (gate_index, copy)
}
/// Reserves one operation slot in the current `U32SubtractionGate`, adding a new gate when
/// `batched_gates.current_u32_subtraction_gate` is `None`.
///
/// Returns `(gate_index, copy)`, where `copy` is the operation of that gate the caller may use.
pub(crate) fn find_u32_subtraction_gate(&mut self) -> (usize, usize) {
    let tracked = self.batched_gates.current_u32_subtraction_gate;
    let (gate_index, copy) = tracked.unwrap_or_else(|| {
        let gate = U32SubtractionGate::new_from_config(&self.config);
        let gate_index = self.add_gate(gate, vec![]);
        (gate_index, 0)
    });

    // Once the last operation has been handed out the gate is no longer tracked.
    let last_op = U32SubtractionGate::<F, D>::num_ops(&self.config) - 1;
    self.batched_gates.current_u32_subtraction_gate = if copy == last_op {
        None
    } else {
        Some((gate_index, copy + 1))
    };

    (gate_index, copy)
}
/// Hands out a free `ConstantGate` slot as `(gate_index, instance)`.
///
/// If `batched_gates.free_constant` is empty, a new `ConstantGate` with
/// `config.constant_gate_size` constants is added first. Afterwards `free_constant` points at
/// the following instance, or is cleared when the returned instance was the gate's last one.
fn constant_gate_instance(&mut self) -> (usize, usize) {
    let (gate, instance) = match self.batched_gates.free_constant {
        Some(slot) => slot,
        None => {
            let num_consts = self.config.constant_gate_size;
            // The new `ConstantGate` starts out with all-zero constants.
            // These will be overwritten by `constant` as the gate instances are filled.
            let new_gate =
                self.add_gate(ConstantGate { num_consts }, vec![F::ZERO; num_consts]);
            (new_gate, 0)
        }
    };

    let following = instance + 1;
    self.batched_gates.free_constant = if following < self.config.constant_gate_size {
        Some((gate, following))
    } else {
        None
    };

    (gate, instance)
}
/// Pads every partially-used base-field arithmetic gate with dummy operations on zero, so
/// that all `ArithmeticGate` are run.
///
/// Panics if `batched_gates.free_base_arithmetic` is not empty once the padding is done.
fn fill_base_arithmetic_gates(&mut self) {
    let zero = self.zero();
    // Iterate over a snapshot: the calls below update `free_base_arithmetic` themselves.
    let partially_used = self.batched_gates.free_base_arithmetic.clone();
    for ((const_0, const_1), (_, first_free)) in partially_used {
        let num_ops = ArithmeticGate::num_ops(&self.config);
        for _ in first_free..num_ops {
            // Wiring in zero directly would trigger an optimization that skips the operation
            // and returns zero, so use a virtual target and connect it to zero afterward.
            let placeholder = self.add_virtual_target();
            self.arithmetic(const_0, const_1, placeholder, placeholder, placeholder);
            self.connect(placeholder, zero);
        }
    }
    assert!(self.batched_gates.free_base_arithmetic.is_empty());
}
/// Pads every partially-used extension arithmetic gate with dummy operations on zero, so that
/// all `ArithmeticExtensionGenerator`s are run.
///
/// Panics if `batched_gates.free_arithmetic` is not empty once the padding is done.
fn fill_arithmetic_gates(&mut self) {
    let zero = self.zero_extension();
    // Iterate over a snapshot: the calls below update `free_arithmetic` themselves.
    let partially_used = self.batched_gates.free_arithmetic.clone();
    for ((const_0, const_1), (_, first_free)) in partially_used {
        let num_ops = ArithmeticExtensionGate::<D>::num_ops(&self.config);
        for _ in first_free..num_ops {
            // Wiring in zero directly would trigger an optimization that skips the operation
            // and returns zero, so use a virtual target and connect it to zero afterward.
            let placeholder = self.add_virtual_extension_target();
            self.arithmetic_extension(const_0, const_1, placeholder, placeholder, placeholder);
            self.connect_extension(placeholder, zero);
        }
    }
    assert!(self.batched_gates.free_arithmetic.is_empty());
}
/// Pads every partially-used multiplication gate tracked in `batched_gates.free_mul` with
/// dummy operations on zero, so that the generators of all those gates are run.
///
/// Panics if `free_mul` is not empty once the padding is done.
// NOTE(review): this relies on `arithmetic_extension(c0, 0, x, x, 0)` consuming a slot of the
// `free_mul` gate for `c0`; the trailing assertion checks the outcome — confirm the routing.
fn fill_mul_gates(&mut self) {
    let zero = self.zero_extension();
    // Iterate over a snapshot: the calls below update `free_mul` themselves.
    let partially_used = self.batched_gates.free_mul.clone();
    for (const_0, (_, first_free)) in partially_used {
        let num_ops = MulExtensionGate::<D>::num_ops(&self.config);
        for _ in first_free..num_ops {
            // Wiring in zero directly would trigger an optimization that skips the operation
            // and returns zero, so use a virtual target and connect it to zero afterward.
            let placeholder = self.add_virtual_extension_target();
            self.arithmetic_extension(const_0, F::ZERO, placeholder, placeholder, zero);
            self.connect_extension(placeholder, zero);
        }
    }
    assert!(self.batched_gates.free_mul.is_empty());
}
/// Pads every partially-used random access gate with all-zero accesses, so that all
/// `RandomAccessGenerator`s are run.
fn fill_random_access_gates(&mut self) {
    let zero = self.zero();
    // Iterate over a snapshot of the tracked gates, keyed by their `bits` parameter.
    let partially_used = self.batched_gates.free_random_access.clone();
    for (bits, (_, first_free)) in partially_used {
        let num_copies =
            RandomAccessGate::<F, D>::new_from_config(&self.config, bits).num_copies;
        for _ in first_free..num_copies {
            // Access index zero, claimed element zero, in a list of `2^bits` zeros.
            let zeros = vec![zero; 1 << bits];
            self.random_access(zero, zero, zeros);
        }
    }
}
/// Connects the inputs and switch wire of every unused switch-gate copy to zero, so that all
/// `SwitchGenerator`s are run.
fn fill_switch_gates(&mut self) {
    let zero = self.zero();
    let num_tracked = self.batched_gates.current_switch_gates.len();

    // Entry `slot` holds the partially-used gate for chunk size `slot + 1`.
    for slot in 0..num_tracked {
        let chunk_size = slot + 1;
        let (gate, gate_index, first_free) =
            match self.batched_gates.current_switch_gates[slot].clone() {
                Some(entry) => entry,
                None => continue,
            };

        for copy in first_free..gate.num_copies {
            for element in 0..chunk_size {
                let first_input =
                    Target::wire(gate_index, gate.wire_first_input(copy, element));
                let second_input =
                    Target::wire(gate_index, gate.wire_second_input(copy, element));
                let switch_bool = Target::wire(gate_index, gate.wire_switch_bool(copy));
                self.connect(zero, first_input);
                self.connect(zero, second_input);
                self.connect(zero, switch_bool);
            }
        }
    }
}
/// Pads the current `U32ArithmeticGate` (if one is partially used) with dummy operations on
/// zero, so that all `U32ArithmeticGenerator`s are run.
fn fill_u32_arithmetic_gates(&mut self) {
    let zero = self.zero_u32();
    let first_free = match self.batched_gates.current_u32_arithmetic_gate {
        Some((_, copy)) => copy,
        // No partially-used gate: nothing to pad.
        None => return,
    };

    let num_ops = U32ArithmeticGate::<F, D>::num_ops(&self.config);
    for _ in first_free..num_ops {
        // Operate on a virtual target and tie it to zero afterward.
        let placeholder = self.add_virtual_u32_target();
        self.mul_add_u32(placeholder, placeholder, placeholder);
        self.connect_u32(placeholder, zero);
    }
}
/// Pads the current `U32SubtractionGate` (if one is partially used) with dummy operations on
/// zero, so that all `U32SubtractionGenerator`s are run.
fn fill_u32_subtraction_gates(&mut self) {
    let zero = self.zero_u32();
    let first_free = match self.batched_gates.current_u32_subtraction_gate {
        Some((_, copy)) => copy,
        // No partially-used gate: nothing to pad.
        None => return,
    };

    let num_ops = U32SubtractionGate::<F, D>::num_ops(&self.config);
    for _ in first_free..num_ops {
        // Operate on a virtual target and tie it to zero afterward.
        let placeholder = self.add_virtual_u32_target();
        self.sub_u32(placeholder, placeholder, placeholder);
        self.connect_u32(placeholder, zero);
    }
}
/// Runs every `fill_*` pass in turn, padding the unused slots of each kind of batched gate
/// (extension arithmetic, base arithmetic, multiplication, random access, switch, U32
/// arithmetic and U32 subtraction).
// NOTE(review): `fill_mul_gates` goes through `arithmetic_extension`, and several passes end
// with an emptiness assertion, so the relative order may matter — confirm before reordering.
fn fill_batched_gates(&mut self) {
self.fill_arithmetic_gates();
self.fill_base_arithmetic_gates();
self.fill_mul_gates();
self.fill_random_access_gates();
self.fill_switch_gates();
self.fill_u32_arithmetic_gates();
self.fill_u32_subtraction_gates();
}
}

View File

@ -16,6 +16,7 @@ use crate::iop::target::Target;
use crate::iop::witness::PartialWitness;
use crate::plonk::config::{GenericConfig, Hasher};
use crate::plonk::proof::ProofWithPublicInputs;
use crate::plonk::proof::{CompressedProofWithPublicInputs, ProofWithPublicInputs};
use crate::plonk::prover::prove;
use crate::plonk::verifier::verify;
use crate::util::marking::MarkedTargets;
@ -26,6 +27,9 @@ pub struct CircuitConfig {
pub num_wires: usize,
pub num_routed_wires: usize,
pub constant_gate_size: usize,
/// Whether to use a dedicated gate for base field arithmetic, rather than using a single gate
/// for both base field and extension field arithmetic.
pub use_base_arithmetic_gate: bool,
pub security_bits: usize,
pub rate_bits: usize,
/// The number of challenge points to generate, for IOPs that have soundness errors of (roughly)
@ -45,30 +49,35 @@ impl Default for CircuitConfig {
}
impl CircuitConfig {
/// Returns `1 / 2^rate_bits` as an `f64`.
pub fn rate(&self) -> f64 {
1.0 / ((1 << self.rate_bits) as f64)
}
/// Returns the number of wires that are not routed, i.e. `num_wires - num_routed_wires`.
pub fn num_advice_wires(&self) -> usize {
self.num_wires - self.num_routed_wires
}
/// A typical recursion config, without zero-knowledge, targeting ~100 bit security.
pub(crate) fn standard_recursion_config() -> Self {
pub fn standard_recursion_config() -> Self {
Self {
num_wires: 143,
num_routed_wires: 25,
constant_gate_size: 6,
num_wires: 135,
num_routed_wires: 80,
constant_gate_size: 5,
use_base_arithmetic_gate: true,
security_bits: 100,
rate_bits: 3,
num_challenges: 2,
zero_knowledge: false,
cap_height: 3,
cap_height: 4,
fri_config: FriConfig {
proof_of_work_bits: 16,
reduction_strategy: FriReductionStrategy::ConstantArityBits(3, 5),
reduction_strategy: FriReductionStrategy::ConstantArityBits(4, 5),
num_query_rounds: 28,
},
}
}
pub(crate) fn standard_recursion_zk_config() -> Self {
pub fn standard_recursion_zk_config() -> Self {
CircuitConfig {
zero_knowledge: true,
..Self::standard_recursion_config()
@ -96,6 +105,13 @@ impl<F: Extendable<D>, C: GenericConfig<D, F = F>, const D: usize> CircuitData<F
/// Verifies `proof_with_pis` by delegating to the free function `verify` with this circuit's
/// `verifier_only` and `common` data.
pub fn verify(&self, proof_with_pis: ProofWithPublicInputs<F, C, D>) -> Result<()> {
verify(proof_with_pis, &self.verifier_only, &self.common)
}
/// Verifies a compressed proof by calling its own `verify` method with this circuit's
/// `verifier_only` and `common` data.
pub fn verify_compressed(
&self,
compressed_proof_with_pis: CompressedProofWithPublicInputs<F, D>,
) -> Result<()> {
compressed_proof_with_pis.verify(&self.verifier_only, &self.common)
}
}
/// Circuit data required by the prover. This may be thought of as a proving key, although it
@ -132,6 +148,13 @@ impl<F: Extendable<D>, C: GenericConfig<D, F = F>, const D: usize> VerifierCircu
/// Verifies `proof_with_pis` by delegating to the free function `verify` with this verifier's
/// `verifier_only` and `common` data.
pub fn verify(&self, proof_with_pis: ProofWithPublicInputs<F, C, D>) -> Result<()> {
verify(proof_with_pis, &self.verifier_only, &self.common)
}
/// Verifies a compressed proof by calling its own `verify` method with this verifier's
/// `verifier_only` and `common` data.
pub fn verify_compressed(
&self,
compressed_proof_with_pis: CompressedProofWithPublicInputs<F, D>,
) -> Result<()> {
compressed_proof_with_pis.verify(&self.verifier_only, &self.common)
}
}
/// Circuit data required by the prover, but not the verifier.
@ -194,8 +217,8 @@ pub struct CommonCircuitData<F: Extendable<D>, C: GenericConfig<D, F = F>, const
/// The `{k_i}` values used in `S_ID_i` in Plonk's permutation argument.
pub(crate) k_is: Vec<F>,
/// The number of partial products needed to compute the `Z` polynomials and the number
/// of partial products needed to compute the final product.
/// The number of partial products needed to compute the `Z` polynomials and
/// the number of original elements consumed in `partial_products()`.
pub(crate) num_partial_products: (usize, usize),
/// A digest of the "circuit" (i.e. the instance, minus public inputs), which can be used to
@ -228,11 +251,6 @@ impl<F: Extendable<D>, C: GenericConfig<D, F = F>, const D: usize> CommonCircuit
self.quotient_degree_factor * self.degree()
}
/// Returns `2 * num_challenges + num_gate_constraints`.
pub fn total_constraints(&self) -> usize {
// 2 constraints for each Z check.
self.config.num_challenges * 2 + self.num_gate_constraints
}
/// Range of the constants polynomials in the `constants_sigmas_commitment`.
pub fn constants_range(&self) -> Range<usize> {
0..self.num_constants

View File

@ -11,7 +11,7 @@ use crate::plonk::proof::{
CompressedProof, CompressedProofWithPublicInputs, FriInferredElements, OpeningSet, Proof,
ProofChallenges, ProofWithPublicInputs,
};
use crate::polynomial::polynomial::PolynomialCoeffs;
use crate::polynomial::PolynomialCoeffs;
use crate::util::reverse_bits;
fn get_challenges<F: Extendable<D>, C: GenericConfig<D, F = F>, const D: usize>(

View File

@ -5,7 +5,7 @@ use rayon::prelude::*;
use crate::field::field_types::Field;
use crate::iop::target::Target;
use crate::iop::wire::Wire;
use crate::polynomial::polynomial::PolynomialValues;
use crate::polynomial::PolynomialValues;
/// Disjoint Set Forest data-structure following https://en.wikipedia.org/wiki/Disjoint-set_data_structure.
pub struct Forest {
@ -45,15 +45,23 @@ impl Forest {
}
/// Path compression method, see https://en.wikipedia.org/wiki/Disjoint-set_data_structure#Finding_set_representatives.
pub fn find(&mut self, x_index: usize) -> usize {
let x_parent = self.parents[x_index];
if x_parent != x_index {
let root_index = self.find(x_parent);
self.parents[x_index] = root_index;
root_index
} else {
x_index
pub fn find(&mut self, mut x_index: usize) -> usize {
    // Implemented with two loops rather than recursion: chains can be long enough for a
    // recursive version to overflow the stack.

    // Pass 1: climb from `x_index` until reaching a node that is its own parent.
    let mut root = x_index;
    loop {
        let parent = self.parents[root];
        if parent == root {
            break;
        }
        root = parent;
    }

    // Pass 2: walk the same chain again, re-pointing every node directly at the root.
    loop {
        let next = self.parents[x_index];
        if next == x_index {
            break;
        }
        self.parents[x_index] = root;
        x_index = next;
    }

    root
}
/// Merge two sets.

View File

@ -42,17 +42,6 @@ impl PlonkPolynomials {
index: 3,
blinding: true,
};
/// Test-only lookup of a polynomial set by index: `0` is `CONSTANTS_SIGMAS`, `1` is `WIRES`,
/// `2` is `ZS_PARTIAL_PRODUCTS` and `3` is `QUOTIENT`. Panics for any other index.
#[cfg(test)]
pub fn polynomials(i: usize) -> PolynomialsIndexBlinding {
match i {
0 => Self::CONSTANTS_SIGMAS,
1 => Self::WIRES,
2 => Self::ZS_PARTIAL_PRODUCTS,
3 => Self::QUOTIENT,
_ => panic!("There are only 4 sets of polynomials in Plonk."),
}
}
}
/// Evaluate the polynomial which vanishes on any multiplicative subgroup of a given order `n`.

View File

@ -164,12 +164,12 @@ impl<F: Extendable<D>, C: GenericConfig<D, F = F>, const D: usize>
) -> anyhow::Result<ProofWithPublicInputs<F, C, D>> {
let challenges = self.get_challenges(common_data)?;
let fri_inferred_elements = self.get_inferred_elements(&challenges, common_data);
let compressed_proof =
let decompressed_proof =
self.proof
.decompress(&challenges, fri_inferred_elements, common_data);
Ok(ProofWithPublicInputs {
public_inputs: self.public_inputs,
proof: compressed_proof,
proof: decompressed_proof,
})
}
@ -180,13 +180,13 @@ impl<F: Extendable<D>, C: GenericConfig<D, F = F>, const D: usize>
) -> anyhow::Result<()> {
let challenges = self.get_challenges(common_data)?;
let fri_inferred_elements = self.get_inferred_elements(&challenges, common_data);
let compressed_proof =
let decompressed_proof =
self.proof
.decompress(&challenges, fri_inferred_elements, common_data);
verify_with_challenges(
ProofWithPublicInputs {
public_inputs: self.public_inputs,
proof: compressed_proof,
proof: decompressed_proof,
},
challenges,
verifier_data,
@ -312,6 +312,7 @@ mod tests {
use crate::field::field_types::Field;
use crate::fri::reduction_strategies::FriReductionStrategy;
use crate::gates::noop::NoopGate;
use crate::iop::witness::PartialWitness;
use crate::plonk::circuit_builder::CircuitBuilder;
use crate::plonk::circuit_data::CircuitConfig;
@ -340,6 +341,9 @@ mod tests {
let zt = builder.constant(z);
let comp_zt = builder.mul(xt, yt);
builder.connect(zt, comp_zt);
for _ in 0..100 {
builder.add_gate(NoopGate, vec![]);
}
let data = builder.build::<C>();
let proof = data.prove(pw)?;
verify(proof.clone(), &data.verifier_only, &data.common)?;
@ -350,6 +354,6 @@ mod tests {
assert_eq!(proof, decompressed_compressed_proof);
verify(proof, &data.verifier_only, &data.common)?;
compressed_proof.verify(&data.verifier_only, &data.common)
data.verify_compressed(compressed_proof)
}
}

View File

@ -1,3 +1,5 @@
use std::mem::swap;
use anyhow::Result;
use rayon::prelude::*;
@ -13,9 +15,9 @@ use crate::plonk::plonk_common::ZeroPolyOnCoset;
use crate::plonk::proof::{Proof, ProofWithPublicInputs};
use crate::plonk::vanishing_poly::eval_vanishing_poly_base_batch;
use crate::plonk::vars::EvaluationVarsBase;
use crate::polynomial::polynomial::{PolynomialCoeffs, PolynomialValues};
use crate::polynomial::{PolynomialCoeffs, PolynomialValues};
use crate::timed;
use crate::util::partial_products::partial_products;
use crate::util::partial_products::{partial_products_and_z_gx, quotient_chunk_products};
use crate::util::timing::TimingTree;
use crate::util::{log2_ceil, transpose};
@ -89,28 +91,22 @@ pub(crate) fn prove<F: Extendable<D>, C: GenericConfig<D, F = F>, const D: usize
common_data.quotient_degree_factor < common_data.config.num_routed_wires,
"When the number of routed wires is smaller that the degree, we should change the logic to avoid computing partial products."
);
let mut partial_products = timed!(
let mut partial_products_and_zs = timed!(
timing,
"compute partial products",
all_wires_permutation_partial_products(&witness, &betas, &gammas, prover_data, common_data)
);
let plonk_z_vecs = timed!(
timing,
"compute Z's",
compute_zs(&partial_products, common_data)
);
// Z is expected at the front of our batch; see `zs_range` and `partial_products_range`.
let plonk_z_vecs = partial_products_and_zs
.iter_mut()
.map(|partial_products_and_z| partial_products_and_z.pop().unwrap())
.collect();
let zs_partial_products = [plonk_z_vecs, partial_products_and_zs.concat()].concat();
// The first polynomial in `partial_products` represent the final product used in the
// computation of `Z`. It isn't needed anymore so we discard it.
partial_products.iter_mut().for_each(|part| {
part.remove(0);
});
let zs_partial_products = [plonk_z_vecs, partial_products.concat()].concat();
let zs_partial_products_commitment = timed!(
let partial_products_and_zs_commitment = timed!(
timing,
"commit to Z's",
"commit to partial products and Z's",
PolynomialBatchCommitment::from_values(
zs_partial_products,
config.rate_bits,
@ -121,7 +117,7 @@ pub(crate) fn prove<F: Extendable<D>, C: GenericConfig<D, F = F>, const D: usize
)
);
challenger.observe_cap(&zs_partial_products_commitment.merkle_tree.cap);
challenger.observe_cap(&partial_products_and_zs_commitment.merkle_tree.cap);
let alphas = challenger.get_n_challenges(num_challenges);
@ -133,7 +129,7 @@ pub(crate) fn prove<F: Extendable<D>, C: GenericConfig<D, F = F>, const D: usize
prover_data,
&public_inputs_hash,
&wires_commitment,
&zs_partial_products_commitment,
&partial_products_and_zs_commitment,
&betas,
&gammas,
&alphas,
@ -148,7 +144,6 @@ pub(crate) fn prove<F: Extendable<D>, C: GenericConfig<D, F = F>, const D: usize
.into_par_iter()
.flat_map(|mut quotient_poly| {
quotient_poly.trim();
// TODO: Return Result instead of panicking.
quotient_poly.pad(quotient_degree).expect(
"Quotient has failed, the vanishing polynomial is not divisible by `Z_H",
);
@ -182,7 +177,7 @@ pub(crate) fn prove<F: Extendable<D>, C: GenericConfig<D, F = F>, const D: usize
&[
&prover_data.constants_sigmas_commitment,
&wires_commitment,
&zs_partial_products_commitment,
&partial_products_and_zs_commitment,
&quotient_polys_commitment,
],
zeta,
@ -194,7 +189,7 @@ pub(crate) fn prove<F: Extendable<D>, C: GenericConfig<D, F = F>, const D: usize
let proof = Proof {
wires_cap: wires_commitment.merkle_tree.cap,
plonk_zs_partial_products_cap: zs_partial_products_commitment.merkle_tree.cap,
plonk_zs_partial_products_cap: partial_products_and_zs_commitment.merkle_tree.cap,
quotient_polys_cap: quotient_polys_commitment.merkle_tree.cap,
openings,
opening_proof,
@ -219,7 +214,7 @@ fn all_wires_permutation_partial_products<
) -> Vec<Vec<PolynomialValues<F>>> {
(0..common_data.config.num_challenges)
.map(|i| {
wires_permutation_partial_products(
wires_permutation_partial_products_and_zs(
witness,
betas[i],
gammas[i],
@ -233,7 +228,7 @@ fn all_wires_permutation_partial_products<
/// Compute the partial products used in the `Z` polynomial.
/// Returns the polynomials interpolating `partial_products(f / g)`
/// where `f, g` are the products in the definition of `Z`: `Z(g^i) = f / g`.
fn wires_permutation_partial_products<
fn wires_permutation_partial_products_and_zs<
F: Extendable<D>,
C: GenericConfig<D, F = F>,
const D: usize,
@ -247,7 +242,8 @@ fn wires_permutation_partial_products<
let degree = common_data.quotient_degree_factor;
let subgroup = &prover_data.subgroup;
let k_is = &common_data.k_is;
let values = subgroup
let (num_prods, _final_num_prod) = common_data.num_partial_products;
let all_quotient_chunk_products = subgroup
.par_iter()
.enumerate()
.map(|(i, &x)| {
@ -271,49 +267,26 @@ fn wires_permutation_partial_products<
.map(|(num, den_inv)| num * den_inv)
.collect::<Vec<_>>();
let quotient_partials = partial_products(&quotient_values, degree);
// This is the final product for the quotient.
let quotient = quotient_partials
[common_data.num_partial_products.0 - common_data.num_partial_products.1..]
.iter()
.copied()
.product();
// We add the quotient at the beginning of the vector to reuse them later in the computation of `Z`.
[vec![quotient], quotient_partials].concat()
quotient_chunk_products(&quotient_values, degree)
})
.collect::<Vec<_>>();
transpose(&values)
let mut z_x = F::ONE;
let mut all_partial_products_and_zs = Vec::new();
for quotient_chunk_products in all_quotient_chunk_products {
let mut partial_products_and_z_gx =
partial_products_and_z_gx(z_x, &quotient_chunk_products);
// The last term is Z(gx), but we replace it with Z(x), otherwise Z would end up shifted.
swap(&mut z_x, &mut partial_products_and_z_gx[num_prods]);
all_partial_products_and_zs.push(partial_products_and_z_gx);
}
transpose(&all_partial_products_and_zs)
.into_par_iter()
.map(PolynomialValues::new)
.collect()
}
/// Computes one `Z` polynomial per challenge by calling `compute_z` on `partial_products[i]`
/// for every `i` in `0..num_challenges`.
fn compute_zs<F: Extendable<D>, C: GenericConfig<D, F = F>, const D: usize>(
partial_products: &[Vec<PolynomialValues<F>>],
common_data: &CommonCircuitData<F, C, D>,
) -> Vec<PolynomialValues<F>> {
(0..common_data.config.num_challenges)
.map(|i| compute_z(&partial_products[i], common_data))
.collect()
}
/// Compute the `Z` polynomial by reusing the computations done in `wires_permutation_partial_products`.
///
/// The result has `common_data.degree()` values: the first is `F::ONE`, and each following
/// value is the previous one multiplied by `partial_products[0].values[i - 1]`.
fn compute_z<F: Extendable<D>, C: GenericConfig<D, F = F>, const D: usize>(
partial_products: &[PolynomialValues<F>],
common_data: &CommonCircuitData<F, C, D>,
) -> PolynomialValues<F> {
// Running product, seeded with one.
let mut plonk_z_points = vec![F::ONE];
for i in 1..common_data.degree() {
// Only the first polynomial of `partial_products` is read here.
let quotient = partial_products[0].values[i - 1];
let last = *plonk_z_points.last().unwrap();
plonk_z_points.push(last * quotient);
}
plonk_z_points.into()
}
const BATCH_SIZE: usize = 32;
fn compute_quotient_polys<'a, F: Extendable<D>, C: GenericConfig<D, F = F>, const D: usize>(

View File

@ -133,6 +133,7 @@ mod tests {
use crate::fri::reduction_strategies::FriReductionStrategy;
use crate::fri::FriConfig;
use crate::gadgets::polynomial::PolynomialCoeffsExtTarget;
use crate::gates::noop::NoopGate;
use crate::hash::merkle_proofs::MerkleProofTarget;
use crate::iop::witness::{PartialWitness, Witness};
use crate::plonk::circuit_data::VerifierOnlyCircuitData;
@ -369,9 +370,8 @@ mod tests {
type F = <C as GenericConfig<D>>::F;
let config = CircuitConfig::standard_recursion_config();
let (proof, vd, cd) = dummy_proof::<F, C, D>(&config, 8_000)?;
let (proof, _vd, cd) =
recursive_proof::<F, C, C, D>(proof, vd, cd, &config, &config, true, true)?;
let (proof, vd, cd) = dummy_proof::<F,C, D>(&config, 4_000)?;
let (proof, _vd, cd) = recursive_proof::<F,C,C,D>(proof, vd, cd, &config, &config, None, true, true)?;
test_serialization(&proof, &cd)?;
Ok(())
@ -388,11 +388,14 @@ mod tests {
let config = CircuitConfig::standard_recursion_config();
let (proof, vd, cd) = dummy_proof::<F, C, D>(&config, 8_000)?;
// Start with a degree 2^14 proof, then shrink it to 2^13, then to 2^12.
let (proof, vd, cd) = dummy_proof::<F,C, D>(&config, 16_000)?;
assert_eq!(cd.degree_bits, 14);
let (proof, vd, cd) =
recursive_proof::<F, C, C, D>(proof, vd, cd, &config, &config, false, false)?;
let (proof, _vd, cd) =
recursive_proof::<F, KC, C, D>(proof, vd, cd, &config, &config, true, true)?;
recursive_proof::<F,C,C,D>(proof, vd, cd, &config, &config, Some(13), false, false)?;
assert_eq!(cd.degree_bits, 13);
let (proof, _vd, cd) = recursive_proof::<F,KC,C,D>(proof, vd, cd, &config, &config, None, true, true)?;
assert_eq!(cd.degree_bits, 12);
test_serialization(&proof, &cd)?;
@ -412,29 +415,29 @@ mod tests {
let standard_config = CircuitConfig::standard_recursion_config();
// A dummy proof with degree 2^13.
let (proof, vd, cd) = dummy_proof::<F, C, D>(&standard_config, 8_000)?;
assert_eq!(cd.degree_bits, 13);
// An initial dummy proof.
let (proof, vd, cd) = dummy_proof::<F,C, D>(&standard_config, 4_000)?;
assert_eq!(cd.degree_bits, 12);
// A standard recursive proof with degree 2^13.
// A standard recursive proof.
let (proof, vd, cd) = recursive_proof(
proof,
vd,
cd,
&standard_config,
&standard_config,
None,
false,
false,
)?;
assert_eq!(cd.degree_bits, 13);
assert_eq!(cd.degree_bits, 12);
// A high-rate recursive proof with degree 2^13, designed to be verifiable with 2^12
// gates and 48 routed wires.
// A high-rate recursive proof, designed to be verifiable with fewer routed wires.
let high_rate_config = CircuitConfig {
rate_bits: 5,
rate_bits: 7,
fri_config: FriConfig {
proof_of_work_bits: 20,
num_query_rounds: 16,
proof_of_work_bits: 16,
num_query_rounds: 12,
..standard_config.fri_config.clone()
},
..standard_config
@ -445,54 +448,35 @@ mod tests {
cd,
&standard_config,
&high_rate_config,
true,
true,
)?;
assert_eq!(cd.degree_bits, 13);
// A higher-rate recursive proof with degree 2^12, designed to be verifiable with 2^12
// gates and 28 routed wires.
let higher_rate_more_routing_config = CircuitConfig {
rate_bits: 7,
num_routed_wires: 48,
fri_config: FriConfig {
proof_of_work_bits: 23,
num_query_rounds: 11,
..standard_config.fri_config.clone()
},
..high_rate_config.clone()
};
let (proof, vd, cd) = recursive_proof::<F, C, C, D>(
proof,
vd,
cd,
&high_rate_config,
&higher_rate_more_routing_config,
None,
true,
true,
)?;
assert_eq!(cd.degree_bits, 12);
// A final proof of degree 2^12, optimized for size.
// A final proof, optimized for size.
let final_config = CircuitConfig {
cap_height: 0,
num_routed_wires: 32,
rate_bits: 8,
num_routed_wires: 37,
fri_config: FriConfig {
proof_of_work_bits: 20,
reduction_strategy: FriReductionStrategy::MinSize(None),
..higher_rate_more_routing_config.fri_config.clone()
num_query_rounds: 10,
},
..higher_rate_more_routing_config
..high_rate_config
};
let (proof, _vd, cd) = recursive_proof::<F, KC, C, D>(
proof,
vd,
cd,
&higher_rate_more_routing_config,
&high_rate_config,
&final_config,
None,
true,
true,
)?;
assert_eq!(cd.degree_bits, 12);
assert_eq!(cd.degree_bits, 12, "final proof too large");
test_serialization(&proof, &cd)?;
@ -509,16 +493,12 @@ mod tests {
CommonCircuitData<F, C, D>,
)> {
let mut builder = CircuitBuilder::<F, D>::new(config.clone());
let input = builder.add_virtual_target();
for i in 0..num_dummy_gates {
// Use unique constants to force a new `ArithmeticGate`.
let i_f = F::from_canonical_u64(i);
builder.arithmetic(i_f, i_f, input, input, input);
for _ in 0..num_dummy_gates {
builder.add_gate(NoopGate, vec![]);
}
let data = builder.build::<C>();
let mut inputs = PartialWitness::new();
inputs.set_target(input, F::ZERO);
let inputs = PartialWitness::new();
let proof = data.prove(inputs)?;
data.verify(proof.clone())?;
@ -536,6 +516,7 @@ mod tests {
inner_cd: CommonCircuitData<F, InnerC, D>,
inner_config: &CircuitConfig,
config: &CircuitConfig,
min_degree_bits: Option<usize>,
print_gate_counts: bool,
print_timing: bool,
) -> Result<(
@ -556,12 +537,22 @@ mod tests {
&inner_vd.constants_sigmas_cap,
);
builder.add_recursive_verifier(pt, &inner_config, &inner_data, &inner_cd);
builder.add_recursive_verifier(pt, inner_config, &inner_data, &inner_cd);
if print_gate_counts {
builder.print_gate_counts(0);
}
if let Some(min_degree_bits) = min_degree_bits {
// We don't want to pad all the way up to 2^min_degree_bits, as the builder will add a
// few special gates afterward. So just pad to 2^(min_degree_bits - 1) + 1. Then the
// builder will pad to the next power of two, 2^min_degree_bits.
let min_gates = (1 << (min_degree_bits - 1)) + 1;
for _ in builder.num_gates()..min_gates {
builder.add_gate(NoopGate, vec![]);
}
}
let data = builder.build::<C>();
let mut timing = TimingTree::new("prove", Level::Debug);
@ -582,12 +573,12 @@ mod tests {
) -> Result<()> {
let proof_bytes = proof.to_bytes()?;
info!("Proof length: {} bytes", proof_bytes.len());
let proof_from_bytes = ProofWithPublicInputs::from_bytes(proof_bytes, &cd)?;
let proof_from_bytes = ProofWithPublicInputs::from_bytes(proof_bytes, cd)?;
assert_eq!(proof, &proof_from_bytes);
let now = std::time::Instant::now();
let compressed_proof = proof.clone().compress(&cd)?;
let decompressed_compressed_proof = compressed_proof.clone().decompress(&cd)?;
let compressed_proof = proof.clone().compress(cd)?;
let decompressed_compressed_proof = compressed_proof.clone().decompress(cd)?;
info!("{:.4}s to compress proof", now.elapsed().as_secs_f64());
assert_eq!(proof, &decompressed_compressed_proof);
@ -597,7 +588,7 @@ mod tests {
compressed_proof_bytes.len()
);
let compressed_proof_from_bytes =
CompressedProofWithPublicInputs::from_bytes(compressed_proof_bytes, &cd)?;
CompressedProofWithPublicInputs::from_bytes(compressed_proof_bytes, cd)?;
assert_eq!(compressed_proof, compressed_proof_from_bytes);
Ok(())

View File

@ -29,7 +29,7 @@ pub(crate) fn eval_vanishing_poly<F: Extendable<D>, C: GenericConfig<D, F = F>,
alphas: &[F],
) -> Vec<F::Extension> {
let max_degree = common_data.quotient_degree_factor;
let (num_prods, final_num_prod) = common_data.num_partial_products;
let (num_prods, _final_num_prod) = common_data.num_partial_products;
let constraint_terms =
evaluate_gate_constraints(&common_data.gates, common_data.num_gate_constraints, vars);
@ -38,14 +38,12 @@ pub(crate) fn eval_vanishing_poly<F: Extendable<D>, C: GenericConfig<D, F = F>,
let mut vanishing_z_1_terms = Vec::new();
// The terms checking the partial products.
let mut vanishing_partial_products_terms = Vec::new();
// The Z(x) f'(x) - g'(x) Z(g x) terms.
let mut vanishing_v_shift_terms = Vec::new();
let l1_x = plonk_common::eval_l_1(common_data.degree(), x);
for i in 0..common_data.config.num_challenges {
let z_x = local_zs[i];
let z_gz = next_zs[i];
let z_gx = next_zs[i];
vanishing_z_1_terms.push(l1_x * (z_x - F::Extension::ONE));
let numerator_values = (0..common_data.config.num_routed_wires)
@ -63,37 +61,24 @@ pub(crate) fn eval_vanishing_poly<F: Extendable<D>, C: GenericConfig<D, F = F>,
wire_value + s_sigma.scalar_mul(betas[i]) + gammas[i].into()
})
.collect::<Vec<_>>();
let quotient_values = (0..common_data.config.num_routed_wires)
.map(|j| numerator_values[j] / denominator_values[j])
.collect::<Vec<_>>();
// The partial products considered for this iteration of `i`.
let current_partial_products = &partial_products[i * num_prods..(i + 1) * num_prods];
// Check the quotient partial products.
let mut partial_product_check =
check_partial_products(&quotient_values, current_partial_products, max_degree);
// The first checks are of the form `q - n/d` which is a rational function not a polynomial.
// We multiply them by `d` to get checks of the form `q*d - n` which low-degree polynomials.
denominator_values
.chunks(max_degree)
.zip(partial_product_check.iter_mut())
.for_each(|(d, q)| {
*q *= d.iter().copied().product();
});
vanishing_partial_products_terms.extend(partial_product_check);
// The quotient final product is the product of the last `final_num_prod` elements.
let quotient: F::Extension = current_partial_products[num_prods - final_num_prod..]
.iter()
.copied()
.product();
vanishing_v_shift_terms.push(quotient * z_x - z_gz);
let partial_product_checks = check_partial_products(
&numerator_values,
&denominator_values,
current_partial_products,
z_x,
z_gx,
max_degree,
);
vanishing_partial_products_terms.extend(partial_product_checks);
}
let vanishing_terms = [
vanishing_z_1_terms,
vanishing_partial_products_terms,
vanishing_v_shift_terms,
constraint_terms,
]
.concat();
@ -130,7 +115,7 @@ pub(crate) fn eval_vanishing_poly_base_batch<
assert_eq!(s_sigmas_batch.len(), n);
let max_degree = common_data.quotient_degree_factor;
let (num_prods, final_num_prod) = common_data.num_partial_products;
let (num_prods, _final_num_prod) = common_data.num_partial_products;
let num_gate_constraints = common_data.num_gate_constraints;
@ -143,14 +128,11 @@ pub(crate) fn eval_vanishing_poly_base_batch<
let mut numerator_values = Vec::with_capacity(num_routed_wires);
let mut denominator_values = Vec::with_capacity(num_routed_wires);
let mut quotient_values = Vec::with_capacity(num_routed_wires);
// The L_1(x) (Z(x) - 1) vanishing terms.
let mut vanishing_z_1_terms = Vec::with_capacity(num_challenges);
// The terms checking the partial products.
let mut vanishing_partial_products_terms = Vec::new();
// The Z(x) f'(x) - g'(x) Z(g x) terms.
let mut vanishing_v_shift_terms = Vec::with_capacity(num_challenges);
let mut res_batch: Vec<Vec<F>> = Vec::with_capacity(n);
for k in 0..n {
@ -168,7 +150,7 @@ pub(crate) fn eval_vanishing_poly_base_batch<
let l1_x = z_h_on_coset.eval_l1(index, x);
for i in 0..num_challenges {
let z_x = local_zs[i];
let z_gz = next_zs[i];
let z_gx = next_zs[i];
vanishing_z_1_terms.push(l1_x * z_x.sub_one());
numerator_values.extend((0..num_routed_wires).map(|j| {
@ -182,49 +164,33 @@ pub(crate) fn eval_vanishing_poly_base_batch<
let s_sigma = s_sigmas[j];
wire_value + betas[i] * s_sigma + gammas[i]
}));
let denominator_inverses = F::batch_multiplicative_inverse(&denominator_values);
quotient_values.extend(
(0..num_routed_wires).map(|j| numerator_values[j] * denominator_inverses[j]),
);
// The partial products considered for this iteration of `i`.
let current_partial_products = &partial_products[i * num_prods..(i + 1) * num_prods];
// Check the quotient partial products.
let mut partial_product_check =
check_partial_products(&quotient_values, current_partial_products, max_degree);
// The first checks are of the form `q - n/d`, which is a rational function, not a polynomial.
// We multiply them by `d` to get checks of the form `q*d - n`, which are low-degree polynomials.
denominator_values
.chunks(max_degree)
.zip(partial_product_check.iter_mut())
.for_each(|(d, q)| {
*q *= d.iter().copied().product();
});
vanishing_partial_products_terms.extend(partial_product_check);
// The quotient final product is the product of the last `final_num_prod` elements.
let quotient: F = current_partial_products[num_prods - final_num_prod..]
.iter()
.copied()
.product();
vanishing_v_shift_terms.push(quotient * z_x - z_gz);
let partial_product_checks = check_partial_products(
&numerator_values,
&denominator_values,
current_partial_products,
z_x,
z_gx,
max_degree,
);
vanishing_partial_products_terms.extend(partial_product_checks);
numerator_values.clear();
denominator_values.clear();
quotient_values.clear();
}
let vanishing_terms = vanishing_z_1_terms
.iter()
.chain(vanishing_partial_products_terms.iter())
.chain(vanishing_v_shift_terms.iter())
.chain(constraint_terms);
let res = plonk_common::reduce_with_powers_multi(vanishing_terms, alphas);
res_batch.push(res);
vanishing_z_1_terms.clear();
vanishing_partial_products_terms.clear();
vanishing_v_shift_terms.clear();
}
res_batch
}
@ -334,7 +300,7 @@ pub(crate) fn eval_vanishing_poly_recursively<
alphas: &[Target],
) -> Vec<ExtensionTarget<D>> {
let max_degree = common_data.quotient_degree_factor;
let (num_prods, final_num_prod) = common_data.num_partial_products;
let (num_prods, _final_num_prod) = common_data.num_partial_products;
let constraint_terms = with_context!(
builder,
@ -351,8 +317,6 @@ pub(crate) fn eval_vanishing_poly_recursively<
let mut vanishing_z_1_terms = Vec::new();
// The terms checking the partial products.
let mut vanishing_partial_products_terms = Vec::new();
// The Z(x) f'(x) - g'(x) Z(g x) terms.
let mut vanishing_v_shift_terms = Vec::new();
let l1_x = eval_l_1_recursively(builder, common_data.degree(), x, x_pow_deg);
@ -365,14 +329,13 @@ pub(crate) fn eval_vanishing_poly_recursively<
for i in 0..common_data.config.num_challenges {
let z_x = local_zs[i];
let z_gz = next_zs[i];
let z_gx = next_zs[i];
// L_1(x) (Z(x) - 1) = 0.
vanishing_z_1_terms.push(builder.mul_sub_extension(l1_x, z_x, l1_x));
let mut numerator_values = Vec::new();
let mut denominator_values = Vec::new();
let mut quotient_values = Vec::new();
for j in 0..common_data.config.num_routed_wires {
let wire_value = vars.local_wires[j];
@ -385,44 +348,28 @@ pub(crate) fn eval_vanishing_poly_recursively<
let numerator = builder.mul_add_extension(beta_ext, s_ids[j], wire_value_plus_gamma);
let denominator =
builder.mul_add_extension(beta_ext, s_sigmas[j], wire_value_plus_gamma);
let quotient = builder.div_extension(numerator, denominator);
numerator_values.push(numerator);
denominator_values.push(denominator);
quotient_values.push(quotient);
}
// The partial products considered for this iteration of `i`.
let current_partial_products = &partial_products[i * num_prods..(i + 1) * num_prods];
// Check the quotient partial products.
let mut partial_product_check = check_partial_products_recursively(
let partial_product_checks = check_partial_products_recursively(
builder,
&quotient_values,
&numerator_values,
&denominator_values,
current_partial_products,
z_x,
z_gx,
max_degree,
);
// The first checks are of the form `q - n/d`, which is a rational function, not a polynomial.
// We multiply them by `d` to get checks of the form `q*d - n`, which are low-degree polynomials.
denominator_values
.chunks(max_degree)
.zip(partial_product_check.iter_mut())
.for_each(|(d, q)| {
let mut v = d.to_vec();
v.push(*q);
*q = builder.mul_many_extension(&v);
});
vanishing_partial_products_terms.extend(partial_product_check);
// The quotient final product is the product of the last `final_num_prod` elements.
let quotient =
builder.mul_many_extension(&current_partial_products[num_prods - final_num_prod..]);
vanishing_v_shift_terms.push(builder.mul_sub_extension(quotient, z_x, z_gz));
vanishing_partial_products_terms.extend(partial_product_checks);
}
let vanishing_terms = [
vanishing_z_1_terms,
vanishing_partial_products_terms,
vanishing_v_shift_terms,
constraint_terms,
]
.concat();

View File

@ -1,4 +1,3 @@
use std::convert::TryInto;
use std::ops::Range;
use crate::field::extension_field::algebra::ExtensionAlgebra;

View File

@ -1,7 +1,6 @@
use crate::field::fft::{fft, ifft};
use crate::field::field_types::Field;
use crate::polynomial::polynomial::PolynomialCoeffs;
use crate::util::{log2_ceil, log2_strict};
use crate::polynomial::PolynomialCoeffs;
use crate::util::log2_ceil;
impl<F: Field> PolynomialCoeffs<F> {
/// Polynomial division.
@ -67,63 +66,6 @@ impl<F: Field> PolynomialCoeffs<F> {
}
}
/// Divides this polynomial (in coefficient form) by `Z_H = X^n - 1`.
///
/// The division is done pointwise on the coset `{g.w^i}`, where `g` is
/// `F::MULTIPLICATIVE_GROUP_GENERATOR` and `w` is a primitive root of unity whose order is the
/// number of stored coefficients.
///
/// This assumes `Z_H | self`, otherwise the result is meaningless.
pub(crate) fn divide_by_z_h(&self, n: usize) -> PolynomialCoeffs<F> {
    let mut shifted = self.clone();
    // TODO: Is this special case needed?
    if shifted.coeffs.iter().all(|c| *c == F::ZERO) {
        return shifted;
    }

    let g = F::MULTIPLICATIVE_GROUP_GENERATOR;
    // Scale the i-th coefficient by `g^i`. Then `shifted(w^j) = self(g.w^j)`.
    for (coeff, g_pow) in shifted.coeffs.iter_mut().zip(g.powers()) {
        *coeff *= g_pow;
    }

    let root = F::primitive_root_of_unity(log2_strict(shifted.len()));
    // The evaluations of `self` on `{g.w^i}`.
    let mut evals = fft(&shifted);

    // The values of `Z_H` on the coset are `g^n.w^(n*i) - 1`; invert them all in one batch.
    let g_to_n = g.exp_u64(n as u64);
    let root_to_n = root.exp_u64(n as u64);
    let denominators = root_to_n
        .powers()
        .take(evals.len())
        .map(|root_pow| g_to_n * root_pow - F::ONE)
        .collect::<Vec<_>>();
    let inverses = F::batch_multiplicative_inverse(&denominators);

    // After this loop `evals` holds the evaluations of `self/Z_H` on `{g.w^i}`.
    for (value, &inv) in evals.values.iter_mut().zip(inverses.iter()) {
        *value *= inv;
    }

    // `quotient` interpolates `evals` on `{w^i}`. Scaling the i-th coefficient by `g^(-i)`
    // turns it into the interpolant on `{g.w^i}`, a.k.a. `self/Z_H`.
    let mut quotient = ifft(&evals);
    let g_inv = g.inverse();
    for (coeff, g_inv_pow) in quotient.coeffs.iter_mut().zip(g_inv.powers()) {
        *coeff *= g_inv_pow;
    }
    quotient
}
/// Let `self=p(X)`, this returns `(p(X)-p(z))/(X-z)` and `p(z)`.
/// See https://en.wikipedia.org/wiki/Horner%27s_method
pub(crate) fn divide_by_linear(&self, z: F) -> (PolynomialCoeffs<F>, F) {
@ -187,35 +129,7 @@ mod tests {
use crate::field::extension_field::quartic::QuarticExtension;
use crate::field::field_types::Field;
use crate::field::goldilocks_field::GoldilocksField;
use crate::polynomial::polynomial::PolynomialCoeffs;
/// Dividing the zero polynomial by `Z_H` must give back the zero polynomial.
#[test]
fn zero_div_z_h() {
    type F = GoldilocksField;
    let zero_poly = PolynomialCoeffs::<F>::zero(16);
    assert_eq!(zero_poly.divide_by_z_h(4), zero_poly);
}
/// Checks `divide_by_z_h` on a hand-computed product `a(x) = Z_4(x) q(x)`.
#[test]
fn division_by_z_h() {
    type F = GoldilocksField;
    let c = |v: u64| F::from_canonical_u64(v);

    // a(x) = Z_4(x) q(x), where
    //     a(x) = 3 x^7 + 4 x^6 + 5 x^5 + 6 x^4 - 3 x^3 - 4 x^2 - 5 x - 6
    //     Z_4(x) = x^4 - 1
    //     q(x) = 3 x^3 + 4 x^2 + 5 x + 6
    let a = PolynomialCoeffs::new(vec![
        -c(6),
        -c(5),
        -c(4),
        -c(3),
        c(6),
        c(5),
        c(4),
        c(3),
    ]);
    let expected_q = PolynomialCoeffs::new(vec![
        c(6),
        c(5),
        c(4),
        c(3),
        F::ZERO,
        F::ZERO,
        F::ZERO,
        F::ZERO,
    ]);

    assert_eq!(a.divide_by_z_h(4), expected_q);
}
use crate::polynomial::PolynomialCoeffs;
#[test]
#[ignore]

View File

@ -1,2 +1,616 @@
pub(crate) mod division;
pub mod polynomial;
use std::cmp::max;
use std::iter::Sum;
use std::ops::{Add, AddAssign, Mul, MulAssign, Sub, SubAssign};
use anyhow::{ensure, Result};
use serde::{Deserialize, Serialize};
use crate::field::extension_field::{Extendable, FieldExtension};
use crate::field::fft::{fft, fft_with_options, ifft, FftRootTable};
use crate::field::field_types::Field;
use crate::util::log2_strict;
/// A polynomial in point-value form.
///
/// The points are implicitly `g^i`, where `g` generates the subgroup whose size equals the number
/// of points.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct PolynomialValues<F: Field> {
/// The evaluations, one per point of the subgroup described above.
pub values: Vec<F>,
}
impl<F: Field> PolynomialValues<F> {
    /// Wraps a vector of evaluations.
    pub fn new(values: Vec<F>) -> Self {
        Self { values }
    }

    /// The number of values stored.
    pub(crate) fn len(&self) -> usize {
        self.values.len()
    }

    /// Converts to coefficient form with an inverse FFT.
    pub fn ifft(&self) -> PolynomialCoeffs<F> {
        ifft(self)
    }

    /// Returns the polynomial whose evaluation on the coset `shift*H` is `self`.
    pub fn coset_ifft(&self, shift: F) -> PolynomialCoeffs<F> {
        let mut poly = self.ifft();
        // Scale the i-th coefficient by `shift^(-i)` to undo the coset shift.
        let inverse_powers = shift.inverse().powers();
        for (coeff, inv_pow) in poly.coeffs.iter_mut().zip(inverse_powers) {
            *coeff *= inv_pow;
        }
        poly
    }

    /// Applies [`Self::lde`] to each polynomial in `polys`.
    pub fn lde_multiple(polys: Vec<Self>, rate_bits: usize) -> Vec<Self> {
        polys.iter().map(|poly| poly.lde(rate_bits)).collect()
    }

    /// Low-degree extension: interpolates the values, extends the coefficient vector to
    /// `len << rate_bits` entries, and evaluates again.
    pub fn lde(&self, rate_bits: usize) -> Self {
        let extended_coeffs = self.ifft().lde(rate_bits);
        extended_coeffs.fft_with_options(Some(rate_bits), None)
    }

    /// Degree of the interpolated polynomial.
    ///
    /// # Panics
    /// Panics if the polynomial has no non-zero coefficient.
    pub fn degree(&self) -> usize {
        let deg_plus_one = self.degree_plus_one();
        deg_plus_one.checked_sub(1).expect("deg(0) is undefined")
    }

    /// Degree of the interpolated polynomial plus one, computed in coefficient form.
    pub fn degree_plus_one(&self) -> usize {
        let coeffs = self.ifft();
        coeffs.degree_plus_one()
    }
}
impl<F: Field> From<Vec<F>> for PolynomialValues<F> {
fn from(values: Vec<F>) -> Self {
Self::new(values)
}
}
/// A polynomial in coefficient form.
///
/// Equality is implemented by hand for this type (see its `PartialEq` impl) rather than derived.
#[derive(Clone, Debug, Serialize, Deserialize)]
// NOTE(review): the empty `bound` presumably stops serde from adding its own bounds on `F` —
// confirm against the serde attribute documentation.
#[serde(bound = "")]
pub struct PolynomialCoeffs<F: Field> {
/// Coefficients in ascending order of degree: `coeffs[i]` is the coefficient of `X^i`.
pub(crate) coeffs: Vec<F>,
}
impl<F: Field> PolynomialCoeffs<F> {
    /// Wraps a coefficient vector; `coeffs[i]` is the coefficient of `X^i`.
    pub fn new(coeffs: Vec<F>) -> Self {
        Self { coeffs }
    }

    /// The polynomial with no stored coefficients.
    pub(crate) fn empty() -> Self {
        Self { coeffs: Vec::new() }
    }

    /// The zero polynomial, stored as `len` zero coefficients.
    pub(crate) fn zero(len: usize) -> Self {
        Self {
            coeffs: vec![F::ZERO; len],
        }
    }

    /// Whether every stored coefficient is zero.
    pub(crate) fn is_zero(&self) -> bool {
        self.coeffs.iter().all(|coeff| coeff.is_zero())
    }

    /// The number of coefficients. This does not filter out any zero coefficients, so it is not
    /// necessarily related to the degree.
    pub fn len(&self) -> usize {
        self.coeffs.len()
    }

    /// `log2_strict` of the number of stored coefficients.
    pub fn log_len(&self) -> usize {
        let len = self.len();
        log2_strict(len)
    }

    /// Splits the coefficient vector into consecutive pieces of at most `chunk_size`
    /// coefficients, each wrapped as its own polynomial.
    pub(crate) fn chunks(&self, chunk_size: usize) -> Vec<Self> {
        self.coeffs
            .chunks(chunk_size)
            .map(|piece| Self::new(piece.to_vec()))
            .collect()
    }

    /// Evaluates the polynomial at `x`.
    pub fn eval(&self, x: F) -> F {
        // Horner's method, starting from the highest-degree coefficient.
        let mut acc = F::ZERO;
        for &coeff in self.coeffs.iter().rev() {
            acc = acc * x + coeff;
        }
        acc
    }

    /// Evaluate the polynomial at a point given its powers. The first power is the point itself, not 1.
    pub fn eval_with_powers(&self, powers: &[F]) -> F {
        debug_assert_eq!(self.coeffs.len(), powers.len() + 1);
        // Start from the constant term, then add `coeffs[i + 1] * powers[i]` for each power.
        let mut acc = self.coeffs[0];
        for (&coeff, &power) in self.coeffs[1..].iter().zip(powers) {
            acc = acc + power * coeff;
        }
        acc
    }

    /// Evaluates the polynomial at a base-field point `x`.
    pub fn eval_base<const D: usize>(&self, x: F::BaseField) -> F
    where
        F: FieldExtension<D>,
    {
        // Horner's method with a base-field multiplier.
        let mut acc = F::ZERO;
        for &coeff in self.coeffs.iter().rev() {
            acc = acc.scalar_mul(x) + coeff;
        }
        acc
    }

    /// Evaluate the polynomial at a point given its powers. The first power is the point itself, not 1.
    pub fn eval_base_with_powers<const D: usize>(&self, powers: &[F::BaseField]) -> F
    where
        F: FieldExtension<D>,
    {
        debug_assert_eq!(self.coeffs.len(), powers.len() + 1);
        let mut acc = self.coeffs[0];
        for (&coeff, &power) in self.coeffs[1..].iter().zip(powers) {
            acc = acc + coeff.scalar_mul(power);
        }
        acc
    }

    /// Applies [`Self::lde`] to each polynomial in `polys`.
    pub fn lde_multiple(polys: Vec<&Self>, rate_bits: usize) -> Vec<Self> {
        polys.iter().map(|poly| poly.lde(rate_bits)).collect()
    }

    /// Returns a copy zero-padded to `len << rate_bits` coefficients.
    pub fn lde(&self, rate_bits: usize) -> Self {
        let extended_len = self.len() << rate_bits;
        self.padded(extended_len)
    }

    /// Zero-pads the coefficient vector in place to `new_len` entries.
    ///
    /// # Errors
    /// Fails if `new_len` is smaller than the current length.
    pub(crate) fn pad(&mut self, new_len: usize) -> Result<()> {
        ensure!(
            new_len >= self.len(),
            "Trying to pad a polynomial of length {} to a length of {}.",
            self.len(),
            new_len
        );
        self.coeffs.resize(new_len, F::ZERO);
        Ok(())
    }

    /// Returns a copy zero-padded to `new_len` entries.
    ///
    /// # Panics
    /// Panics if `new_len` is smaller than the current length.
    pub(crate) fn padded(&self, new_len: usize) -> Self {
        let mut copy = self.clone();
        copy.pad(new_len).unwrap();
        copy
    }

    /// Removes leading zero coefficients.
    pub fn trim(&mut self) {
        let keep = self.degree_plus_one();
        self.coeffs.truncate(keep);
    }

    /// Removes leading zero coefficients.
    pub fn trimmed(&self) -> Self {
        let keep = self.degree_plus_one();
        Self::new(self.coeffs[..keep].to_vec())
    }

    /// Degree of the polynomial + 1, or 0 for a polynomial with no non-zero coefficients.
    pub(crate) fn degree_plus_one(&self) -> usize {
        self.coeffs
            .iter()
            .rposition(|coeff| coeff.is_nonzero())
            .map_or(0, |highest| highest + 1)
    }

    /// Leading coefficient.
    pub fn lead(&self) -> F {
        // The highest-index non-zero coefficient, or zero if there is none.
        self.coeffs
            .iter()
            .rev()
            .find(|coeff| coeff.is_nonzero())
            .copied()
            .unwrap_or(F::ZERO)
    }

    /// Reverse the order of the coefficients, not taking into account the leading zero coefficients.
    pub(crate) fn rev(&self) -> Self {
        let mut coeffs = self.trimmed().coeffs;
        coeffs.reverse();
        Self::new(coeffs)
    }

    /// Converts to point-value form with the crate's `fft`.
    pub fn fft(&self) -> PolynomialValues<F> {
        fft(self)
    }

    /// Same as [`Self::fft`], forwarding `zero_factor` and `root_table` to `fft_with_options`.
    pub fn fft_with_options(
        &self,
        zero_factor: Option<usize>,
        root_table: Option<&FftRootTable<F>>,
    ) -> PolynomialValues<F> {
        fft_with_options(self, zero_factor, root_table)
    }

    /// Returns the evaluation of the polynomial on the coset `shift*H`.
    pub fn coset_fft(&self, shift: F) -> PolynomialValues<F> {
        self.coset_fft_with_options(shift, None, None)
    }

    /// Returns the evaluation of the polynomial on the coset `shift*H`.
    pub fn coset_fft_with_options(
        &self,
        shift: F,
        zero_factor: Option<usize>,
        root_table: Option<&FftRootTable<F>>,
    ) -> PolynomialValues<F> {
        // Scaling the i-th coefficient by `shift^i` turns evaluation on `H` into evaluation on
        // `shift*H`.
        let scaled: Vec<F> = self
            .coeffs
            .iter()
            .zip(shift.powers())
            .map(|(&coeff, shift_pow)| shift_pow * coeff)
            .collect();
        Self::new(scaled).fft_with_options(zero_factor, root_table)
    }

    /// Converts every coefficient into the extension field.
    pub fn to_extension<const D: usize>(&self) -> PolynomialCoeffs<F::Extension>
    where
        F: Extendable<D>,
    {
        let lifted = self.coeffs.iter().map(|&coeff| coeff.into()).collect();
        PolynomialCoeffs::new(lifted)
    }

    /// Multiplies the polynomial by the extension-field element `rhs`.
    pub fn mul_extension<const D: usize>(&self, rhs: F::Extension) -> PolynomialCoeffs<F::Extension>
    where
        F: Extendable<D>,
    {
        let scaled = self
            .coeffs
            .iter()
            .map(|&coeff| rhs.scalar_mul(coeff))
            .collect();
        PolynomialCoeffs::new(scaled)
    }
}
/// Two coefficient vectors are equal when they agree on every index, treating missing
/// high-order coefficients as zero. Trailing zeros therefore do not affect equality.
impl<F: Field> PartialEq for PolynomialCoeffs<F> {
    fn eq(&self, other: &Self) -> bool {
        let (shorter, longer) = if self.coeffs.len() <= other.coeffs.len() {
            (&self.coeffs, &other.coeffs)
        } else {
            (&other.coeffs, &self.coeffs)
        };
        // The shared prefix must match, and whatever the longer vector has beyond it must be zero.
        let (head, tail) = longer.split_at(shorter.len());
        shorter.as_slice() == head && tail.iter().all(|coeff| *coeff == F::ZERO)
    }
}
// `Eq` adds no methods; the comparison itself is the hand-written `PartialEq` impl for this type.
impl<F: Field> Eq for PolynomialCoeffs<F> {}
impl<F: Field> From<Vec<F>> for PolynomialCoeffs<F> {
fn from(coeffs: Vec<F>) -> Self {
Self::new(coeffs)
}
}
/// Coefficient-wise sum; the result has as many coefficients as the longer operand.
impl<F: Field> Add for &PolynomialCoeffs<F> {
    type Output = PolynomialCoeffs<F>;

    fn add(self, rhs: Self) -> Self::Output {
        let len = self.len().max(rhs.len());
        // Pad both operands to a common length so they can be added pairwise.
        let lhs = self.padded(len);
        let rhs = rhs.padded(len);
        let sums = lhs
            .coeffs
            .iter()
            .zip(rhs.coeffs.iter())
            .map(|(&x, &y)| x + y)
            .collect();
        PolynomialCoeffs::new(sums)
    }
}
/// Sums an iterator of polynomials, starting from the empty polynomial.
impl<F: Field> Sum for PolynomialCoeffs<F> {
    fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
        let mut total = Self::empty();
        for poly in iter {
            total = &total + &poly;
        }
        total
    }
}
/// Coefficient-wise difference; the result has as many coefficients as the longer operand.
impl<F: Field> Sub for &PolynomialCoeffs<F> {
    type Output = PolynomialCoeffs<F>;

    fn sub(self, rhs: Self) -> Self::Output {
        let len = self.len().max(rhs.len());
        let mut diff = self.padded(len).coeffs;
        // `diff` is at least as long as `rhs`, so the zip visits every coefficient of `rhs`.
        diff.iter_mut()
            .zip(rhs.coeffs.iter())
            .for_each(|(slot, &c)| *slot -= c);
        PolynomialCoeffs::new(diff)
    }
}
/// In-place coefficient-wise addition of an owned polynomial.
impl<F: Field> AddAssign for PolynomialCoeffs<F> {
    fn add_assign(&mut self, rhs: Self) {
        // Grow if needed so every coefficient of `rhs` has a slot.
        if rhs.len() > self.len() {
            self.coeffs.resize(rhs.len(), F::ZERO);
        }
        self.coeffs
            .iter_mut()
            .zip(rhs.coeffs)
            .for_each(|(slot, r)| *slot += r);
    }
}
/// In-place coefficient-wise addition of a borrowed polynomial.
impl<F: Field> AddAssign<&Self> for PolynomialCoeffs<F> {
    fn add_assign(&mut self, rhs: &Self) {
        // Grow if needed so every coefficient of `rhs` has a slot.
        if rhs.len() > self.len() {
            self.coeffs.resize(rhs.len(), F::ZERO);
        }
        self.coeffs
            .iter_mut()
            .zip(rhs.coeffs.iter())
            .for_each(|(slot, &r)| *slot += r);
    }
}
/// In-place coefficient-wise subtraction of an owned polynomial.
impl<F: Field> SubAssign for PolynomialCoeffs<F> {
    fn sub_assign(&mut self, rhs: Self) {
        // Grow if needed so every coefficient of `rhs` has a slot.
        if rhs.len() > self.len() {
            self.coeffs.resize(rhs.len(), F::ZERO);
        }
        self.coeffs
            .iter_mut()
            .zip(rhs.coeffs)
            .for_each(|(slot, r)| *slot -= r);
    }
}
/// In-place coefficient-wise subtraction of a borrowed polynomial.
impl<F: Field> SubAssign<&Self> for PolynomialCoeffs<F> {
    fn sub_assign(&mut self, rhs: &Self) {
        // Grow if needed so every coefficient of `rhs` has a slot.
        if rhs.len() > self.len() {
            self.coeffs.resize(rhs.len(), F::ZERO);
        }
        self.coeffs
            .iter_mut()
            .zip(rhs.coeffs.iter())
            .for_each(|(slot, &r)| *slot -= r);
    }
}
/// Multiplies every coefficient by the scalar `rhs`, returning a new polynomial.
impl<F: Field> Mul<F> for &PolynomialCoeffs<F> {
    type Output = PolynomialCoeffs<F>;

    fn mul(self, rhs: F) -> Self::Output {
        let mut scaled = Vec::with_capacity(self.len());
        for &coeff in &self.coeffs {
            scaled.push(rhs * coeff);
        }
        PolynomialCoeffs::new(scaled)
    }
}
/// Multiplies every coefficient by the scalar `rhs` in place.
impl<F: Field> MulAssign<F> for PolynomialCoeffs<F> {
    fn mul_assign(&mut self, rhs: F) {
        for coeff in self.coeffs.iter_mut() {
            *coeff *= rhs;
        }
    }
}
impl<F: Field> Mul for &PolynomialCoeffs<F> {
type Output = PolynomialCoeffs<F>;
#[allow(clippy::suspicious_arithmetic_impl)]
fn mul(self, rhs: Self) -> Self::Output {
let new_len = (self.len() + rhs.len()).next_power_of_two();
let a = self.padded(new_len);
let b = rhs.padded(new_len);
let a_evals = a.fft();
let b_evals = b.fft();
let mul_evals: Vec<F> = a_evals
.values
.into_iter()
.zip(b_evals.values)
.map(|(pa, pb)| pa * pb)
.collect();
ifft(&mul_evals.into())
}
}
#[cfg(test)]
mod tests {
    use std::time::Instant;

    use rand::{thread_rng, Rng};

    use super::*;
    use crate::field::goldilocks_field::GoldilocksField;

    /// `trimmed` drops trailing zero coefficients and leaves everything else untouched.
    #[test]
    fn test_trimmed() {
        type F = GoldilocksField;

        assert_eq!(
            PolynomialCoeffs::<F> { coeffs: vec![] }.trimmed(),
            PolynomialCoeffs::<F> { coeffs: vec![] }
        );
        assert_eq!(
            PolynomialCoeffs::<F> {
                coeffs: vec![F::ZERO]
            }
            .trimmed(),
            PolynomialCoeffs::<F> { coeffs: vec![] }
        );
        assert_eq!(
            PolynomialCoeffs::<F> {
                coeffs: vec![F::ONE, F::TWO, F::ZERO, F::ZERO]
            }
            .trimmed(),
            PolynomialCoeffs::<F> {
                coeffs: vec![F::ONE, F::TWO]
            }
        );
    }

    /// `coset_fft` agrees with naive evaluation on the coset, and `coset_ifft` inverts it.
    #[test]
    fn test_coset_fft() {
        type F = GoldilocksField;

        let k = 8;
        let n = 1 << k;
        let poly = PolynomialCoeffs::new(F::rand_vec(n));
        let shift = F::rand();
        let coset_evals = poly.coset_fft(shift).values;

        let generator = F::primitive_root_of_unity(k);
        let naive_coset_evals = F::cyclic_subgroup_coset_known_order(generator, shift, n)
            .into_iter()
            .map(|x| poly.eval(x))
            .collect::<Vec<_>>();
        assert_eq!(coset_evals, naive_coset_evals);

        let ifft_coeffs = PolynomialValues::new(coset_evals).coset_ifft(shift);
        assert_eq!(poly, ifft_coeffs);
    }

    /// `coset_ifft` yields a polynomial whose coset evaluations are the original values.
    #[test]
    fn test_coset_ifft() {
        type F = GoldilocksField;

        let k = 8;
        let n = 1 << k;
        let evals = PolynomialValues::new(F::rand_vec(n));
        let shift = F::rand();
        let coeffs = evals.coset_ifft(shift);

        let generator = F::primitive_root_of_unity(k);
        let naive_coset_evals = F::cyclic_subgroup_coset_known_order(generator, shift, n)
            .into_iter()
            .map(|x| coeffs.eval(x))
            .collect::<Vec<_>>();
        assert_eq!(evals, naive_coset_evals.into());

        let fft_evals = coeffs.coset_fft(shift);
        assert_eq!(evals, fft_evals);
    }

    /// The FFT-based product evaluates to the product of the factors' evaluations.
    #[test]
    fn test_polynomial_multiplication() {
        type F = GoldilocksField;
        let mut rng = thread_rng();
        let (a_deg, b_deg) = (rng.gen_range(1..10_000), rng.gen_range(1..10_000));
        let a = PolynomialCoeffs::new(F::rand_vec(a_deg));
        let b = PolynomialCoeffs::new(F::rand_vec(b_deg));
        // The product is computed once; the second, identical product added no coverage.
        let product = &a * &b;
        for _ in 0..1000 {
            let x = F::rand();
            assert_eq!(product.eval(x), a.eval(x) * b.eval(x));
        }
    }

    /// `a * inv_mod_xn(a, n)` is congruent to 1 modulo `X^n`.
    #[test]
    fn test_inv_mod_xn() {
        type F = GoldilocksField;
        let mut rng = thread_rng();
        let a_deg = rng.gen_range(1..1_000);
        let n = rng.gen_range(1..1_000);
        let a = PolynomialCoeffs::new(F::rand_vec(a_deg));
        let b = a.inv_mod_xn(n);
        let mut m = &a * &b;
        // Reduce modulo `X^n`. `truncate` is a no-op when fewer than `n` coefficients exist,
        // whereas `drain(n..)` would panic.
        m.coeffs.truncate(n);
        m.trim();
        assert_eq!(
            m,
            PolynomialCoeffs::new(vec![F::ONE]),
            "a: {:#?}, b:{:#?}, n:{:#?}, m:{:#?}",
            a,
            b,
            n,
            m
        );
    }

    /// Long division satisfies `a = b * q + r` at random points.
    #[test]
    fn test_polynomial_long_division() {
        type F = GoldilocksField;
        let mut rng = thread_rng();
        let (a_deg, b_deg) = (rng.gen_range(1..10_000), rng.gen_range(1..10_000));
        let a = PolynomialCoeffs::new(F::rand_vec(a_deg));
        let b = PolynomialCoeffs::new(F::rand_vec(b_deg));
        let (q, r) = a.div_rem_long_division(&b);
        for _ in 0..1000 {
            let x = F::rand();
            assert_eq!(a.eval(x), b.eval(x) * q.eval(x) + r.eval(x));
        }
    }

    /// `div_rem` satisfies `a = b * q + r` at random points.
    #[test]
    fn test_polynomial_division() {
        type F = GoldilocksField;
        let mut rng = thread_rng();
        let (a_deg, b_deg) = (rng.gen_range(1..10_000), rng.gen_range(1..10_000));
        let a = PolynomialCoeffs::new(F::rand_vec(a_deg));
        let b = PolynomialCoeffs::new(F::rand_vec(b_deg));
        let (q, r) = a.div_rem(&b);
        for _ in 0..1000 {
            let x = F::rand();
            assert_eq!(a.eval(x), b.eval(x) * q.eval(x) + r.eval(x));
        }
    }

    /// `div_rem` also works when the divisor is a constant polynomial.
    #[test]
    fn test_polynomial_division_by_constant() {
        type F = GoldilocksField;
        let mut rng = thread_rng();
        let a_deg = rng.gen_range(1..10_000);
        let a = PolynomialCoeffs::new(F::rand_vec(a_deg));
        let b = PolynomialCoeffs::from(vec![F::rand()]);
        let (q, r) = a.div_rem(&b);
        for _ in 0..1000 {
            let x = F::rand();
            assert_eq!(a.eval(x), b.eval(x) * q.eval(x) + r.eval(x));
        }
    }

    // Test to see which polynomial division method is faster for divisions of the type
    // `(X^n - 1)/(X - a)`.
    #[test]
    fn test_division_linear() {
        type F = GoldilocksField;
        let mut rng = thread_rng();
        let l = 14;
        let n = 1 << l;
        let g = F::primitive_root_of_unity(l);
        let xn_minus_one = {
            let mut xn_min_one_vec = vec![F::ZERO; n + 1];
            xn_min_one_vec[n] = F::ONE;
            xn_min_one_vec[0] = F::NEG_ONE;
            PolynomialCoeffs::new(xn_min_one_vec)
        };

        let a = g.exp_u64(rng.gen_range(0..(n as u64)));
        let denom = PolynomialCoeffs::new(vec![-a, F::ONE]);

        // Each timing is labelled with the method it measures so the two can be told apart.
        let now = Instant::now();
        xn_minus_one.div_rem(&denom);
        println!("`div_rem` time: {:?}", now.elapsed());

        let now = Instant::now();
        xn_minus_one.div_rem_long_division(&denom);
        println!("`div_rem_long_division` time: {:?}", now.elapsed());
    }

    /// Equality ignores trailing zero coefficients but not any other difference.
    #[test]
    fn eq() {
        type F = GoldilocksField;
        assert_eq!(
            PolynomialCoeffs::<F>::new(vec![]),
            PolynomialCoeffs::new(vec![])
        );
        assert_eq!(
            PolynomialCoeffs::<F>::new(vec![F::ZERO]),
            PolynomialCoeffs::new(vec![F::ZERO])
        );
        assert_eq!(
            PolynomialCoeffs::<F>::new(vec![]),
            PolynomialCoeffs::new(vec![F::ZERO])
        );
        assert_eq!(
            PolynomialCoeffs::<F>::new(vec![F::ZERO]),
            PolynomialCoeffs::new(vec![])
        );
        assert_eq!(
            PolynomialCoeffs::<F>::new(vec![F::ZERO]),
            PolynomialCoeffs::new(vec![F::ZERO, F::ZERO])
        );
        assert_eq!(
            PolynomialCoeffs::<F>::new(vec![F::ONE]),
            PolynomialCoeffs::new(vec![F::ONE, F::ZERO])
        );
        assert_ne!(
            PolynomialCoeffs::<F>::new(vec![]),
            PolynomialCoeffs::new(vec![F::ONE])
        );
        assert_ne!(
            PolynomialCoeffs::<F>::new(vec![F::ZERO]),
            PolynomialCoeffs::new(vec![F::ZERO, F::ONE])
        );
        assert_ne!(
            PolynomialCoeffs::<F>::new(vec![F::ZERO]),
            PolynomialCoeffs::new(vec![F::ONE, F::ZERO])
        );
    }
}

View File

@ -1,635 +0,0 @@
use std::cmp::max;
use std::iter::Sum;
use std::ops::{Add, AddAssign, Mul, MulAssign, Sub, SubAssign};
use anyhow::{ensure, Result};
use serde::{Deserialize, Serialize};
use crate::field::extension_field::{Extendable, FieldExtension};
use crate::field::fft::{fft, fft_with_options, ifft, FftRootTable};
use crate::field::field_types::Field;
use crate::util::log2_strict;
/// A polynomial in point-value form.
///
/// The points are implicitly `g^i`, where `g` generates the subgroup whose size equals the number
/// of points.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct PolynomialValues<F: Field> {
/// The evaluations, one per point of the subgroup described above.
pub values: Vec<F>,
}
impl<F: Field> PolynomialValues<F> {
    /// Wraps a vector of evaluations.
    pub fn new(values: Vec<F>) -> Self {
        Self { values }
    }

    /// `len` evaluations, all equal to zero.
    pub(crate) fn zero(len: usize) -> Self {
        Self {
            values: vec![F::ZERO; len],
        }
    }

    /// The number of values stored.
    pub(crate) fn len(&self) -> usize {
        self.values.len()
    }

    /// Converts to coefficient form with an inverse FFT.
    pub fn ifft(&self) -> PolynomialCoeffs<F> {
        ifft(self)
    }

    /// Returns the polynomial whose evaluation on the coset `shift*H` is `self`.
    pub fn coset_ifft(&self, shift: F) -> PolynomialCoeffs<F> {
        let mut poly = self.ifft();
        // Scale the i-th coefficient by `shift^(-i)` to undo the coset shift.
        let inverse_powers = shift.inverse().powers();
        for (coeff, inv_pow) in poly.coeffs.iter_mut().zip(inverse_powers) {
            *coeff *= inv_pow;
        }
        poly
    }

    /// Applies [`Self::lde`] to each polynomial in `polys`.
    pub fn lde_multiple(polys: Vec<Self>, rate_bits: usize) -> Vec<Self> {
        polys.iter().map(|poly| poly.lde(rate_bits)).collect()
    }

    /// Low-degree extension: interpolates the values, extends the coefficient vector to
    /// `len << rate_bits` entries, and evaluates again.
    pub fn lde(&self, rate_bits: usize) -> Self {
        let extended_coeffs = self.ifft().lde(rate_bits);
        extended_coeffs.fft_with_options(Some(rate_bits), None)
    }

    /// Degree of the interpolated polynomial.
    ///
    /// # Panics
    /// Panics if the polynomial has no non-zero coefficient.
    pub fn degree(&self) -> usize {
        let deg_plus_one = self.degree_plus_one();
        deg_plus_one.checked_sub(1).expect("deg(0) is undefined")
    }

    /// Degree of the interpolated polynomial plus one, computed in coefficient form.
    pub fn degree_plus_one(&self) -> usize {
        let coeffs = self.ifft();
        coeffs.degree_plus_one()
    }
}
impl<F: Field> From<Vec<F>> for PolynomialValues<F> {
fn from(values: Vec<F>) -> Self {
Self::new(values)
}
}
/// A polynomial in coefficient form.
///
/// Equality is implemented by hand for this type (see its `PartialEq` impl) rather than derived.
#[derive(Clone, Debug, Serialize, Deserialize)]
// NOTE(review): the empty `bound` presumably stops serde from adding its own bounds on `F` —
// confirm against the serde attribute documentation.
#[serde(bound = "")]
pub struct PolynomialCoeffs<F: Field> {
/// Coefficients in ascending order of degree: `coeffs[i]` is the coefficient of `X^i`.
pub(crate) coeffs: Vec<F>,
}
impl<F: Field> PolynomialCoeffs<F> {
    /// Wraps a coefficient vector; `coeffs[i]` is the coefficient of `X^i`.
    pub fn new(coeffs: Vec<F>) -> Self {
        Self { coeffs }
    }

    /// Create a new polynomial with its coefficient list padded to the next power of two.
    pub(crate) fn new_padded(mut coeffs: Vec<F>) -> Self {
        // `next_power_of_two` maps 0 to 1, so an empty list becomes a single zero coefficient.
        let padded_len = coeffs.len().next_power_of_two();
        coeffs.resize(padded_len, F::ZERO);
        Self { coeffs }
    }

    /// The polynomial with no stored coefficients.
    pub(crate) fn empty() -> Self {
        Self { coeffs: Vec::new() }
    }

    /// The zero polynomial, stored as `len` zero coefficients.
    pub(crate) fn zero(len: usize) -> Self {
        Self {
            coeffs: vec![F::ZERO; len],
        }
    }

    /// The constant polynomial `1`.
    pub(crate) fn one() -> Self {
        Self {
            coeffs: vec![F::ONE],
        }
    }

    /// Whether every stored coefficient is zero.
    pub(crate) fn is_zero(&self) -> bool {
        self.coeffs.iter().all(|coeff| coeff.is_zero())
    }

    /// The number of coefficients. This does not filter out any zero coefficients, so it is not
    /// necessarily related to the degree.
    pub fn len(&self) -> usize {
        self.coeffs.len()
    }

    /// `log2_strict` of the number of stored coefficients.
    pub fn log_len(&self) -> usize {
        let len = self.len();
        log2_strict(len)
    }

    /// Splits the coefficient vector into consecutive pieces of at most `chunk_size`
    /// coefficients, each wrapped as its own polynomial.
    pub(crate) fn chunks(&self, chunk_size: usize) -> Vec<Self> {
        self.coeffs
            .chunks(chunk_size)
            .map(|piece| Self::new(piece.to_vec()))
            .collect()
    }

    /// Evaluates the polynomial at `x`.
    pub fn eval(&self, x: F) -> F {
        // Horner's method, starting from the highest-degree coefficient.
        let mut acc = F::ZERO;
        for &coeff in self.coeffs.iter().rev() {
            acc = acc * x + coeff;
        }
        acc
    }

    /// Evaluates the polynomial at a base-field point `x`.
    pub fn eval_base<const D: usize>(&self, x: F::BaseField) -> F
    where
        F: FieldExtension<D>,
    {
        // Horner's method with a base-field multiplier.
        let mut acc = F::ZERO;
        for &coeff in self.coeffs.iter().rev() {
            acc = acc.scalar_mul(x) + coeff;
        }
        acc
    }

    /// Applies [`Self::lde`] to each polynomial in `polys`.
    pub fn lde_multiple(polys: Vec<&Self>, rate_bits: usize) -> Vec<Self> {
        polys.iter().map(|poly| poly.lde(rate_bits)).collect()
    }

    /// Returns a copy zero-padded to `len << rate_bits` coefficients.
    pub fn lde(&self, rate_bits: usize) -> Self {
        let extended_len = self.len() << rate_bits;
        self.padded(extended_len)
    }

    /// Zero-pads the coefficient vector in place to `new_len` entries.
    ///
    /// # Errors
    /// Fails if `new_len` is smaller than the current length.
    pub(crate) fn pad(&mut self, new_len: usize) -> Result<()> {
        ensure!(
            new_len >= self.len(),
            "Trying to pad a polynomial of length {} to a length of {}.",
            self.len(),
            new_len
        );
        self.coeffs.resize(new_len, F::ZERO);
        Ok(())
    }

    /// Returns a copy zero-padded to `new_len` entries.
    ///
    /// # Panics
    /// Panics if `new_len` is smaller than the current length.
    pub(crate) fn padded(&self, new_len: usize) -> Self {
        let mut copy = self.clone();
        copy.pad(new_len).unwrap();
        copy
    }

    /// Removes leading zero coefficients.
    pub fn trim(&mut self) {
        let keep = self.degree_plus_one();
        self.coeffs.truncate(keep);
    }

    /// Removes leading zero coefficients.
    pub fn trimmed(&self) -> Self {
        let keep = self.degree_plus_one();
        Self::new(self.coeffs[..keep].to_vec())
    }

    /// Degree of the polynomial + 1, or 0 for a polynomial with no non-zero coefficients.
    pub(crate) fn degree_plus_one(&self) -> usize {
        self.coeffs
            .iter()
            .rposition(|coeff| coeff.is_nonzero())
            .map_or(0, |highest| highest + 1)
    }

    /// Leading coefficient.
    pub fn lead(&self) -> F {
        // The highest-index non-zero coefficient, or zero if there is none.
        self.coeffs
            .iter()
            .rev()
            .find(|coeff| coeff.is_nonzero())
            .copied()
            .unwrap_or(F::ZERO)
    }

    /// Reverse the order of the coefficients, not taking into account the leading zero coefficients.
    pub(crate) fn rev(&self) -> Self {
        let mut coeffs = self.trimmed().coeffs;
        coeffs.reverse();
        Self::new(coeffs)
    }

    /// Converts to point-value form with the crate's `fft`.
    pub fn fft(&self) -> PolynomialValues<F> {
        fft(self)
    }

    /// Same as [`Self::fft`], forwarding `zero_factor` and `root_table` to `fft_with_options`.
    pub fn fft_with_options(
        &self,
        zero_factor: Option<usize>,
        root_table: Option<&FftRootTable<F>>,
    ) -> PolynomialValues<F> {
        fft_with_options(self, zero_factor, root_table)
    }

    /// Returns the evaluation of the polynomial on the coset `shift*H`.
    pub fn coset_fft(&self, shift: F) -> PolynomialValues<F> {
        self.coset_fft_with_options(shift, None, None)
    }

    /// Returns the evaluation of the polynomial on the coset `shift*H`.
    pub fn coset_fft_with_options(
        &self,
        shift: F,
        zero_factor: Option<usize>,
        root_table: Option<&FftRootTable<F>>,
    ) -> PolynomialValues<F> {
        // Scaling the i-th coefficient by `shift^i` turns evaluation on `H` into evaluation on
        // `shift*H`.
        let scaled: Vec<F> = self
            .coeffs
            .iter()
            .zip(shift.powers())
            .map(|(&coeff, shift_pow)| shift_pow * coeff)
            .collect();
        Self::new(scaled).fft_with_options(zero_factor, root_table)
    }

    /// Converts every coefficient into the extension field.
    pub fn to_extension<const D: usize>(&self) -> PolynomialCoeffs<F::Extension>
    where
        F: Extendable<D>,
    {
        let lifted = self.coeffs.iter().map(|&coeff| coeff.into()).collect();
        PolynomialCoeffs::new(lifted)
    }

    /// Multiplies the polynomial by the extension-field element `rhs`.
    pub fn mul_extension<const D: usize>(&self, rhs: F::Extension) -> PolynomialCoeffs<F::Extension>
    where
        F: Extendable<D>,
    {
        let scaled = self
            .coeffs
            .iter()
            .map(|&coeff| rhs.scalar_mul(coeff))
            .collect();
        PolynomialCoeffs::new(scaled)
    }
}
/// Two coefficient vectors are equal when they agree on every index, treating missing
/// high-order coefficients as zero. Trailing zeros therefore do not affect equality.
impl<F: Field> PartialEq for PolynomialCoeffs<F> {
    fn eq(&self, other: &Self) -> bool {
        let (shorter, longer) = if self.coeffs.len() <= other.coeffs.len() {
            (&self.coeffs, &other.coeffs)
        } else {
            (&other.coeffs, &self.coeffs)
        };
        // The shared prefix must match, and whatever the longer vector has beyond it must be zero.
        let (head, tail) = longer.split_at(shorter.len());
        shorter.as_slice() == head && tail.iter().all(|coeff| *coeff == F::ZERO)
    }
}
// `Eq` adds no methods; the comparison itself is the hand-written `PartialEq` impl for this type.
impl<F: Field> Eq for PolynomialCoeffs<F> {}
impl<F: Field> From<Vec<F>> for PolynomialCoeffs<F> {
fn from(coeffs: Vec<F>) -> Self {
Self::new(coeffs)
}
}
/// Coefficient-wise sum; the result has as many coefficients as the longer operand.
impl<F: Field> Add for &PolynomialCoeffs<F> {
    type Output = PolynomialCoeffs<F>;

    fn add(self, rhs: Self) -> Self::Output {
        let len = self.len().max(rhs.len());
        // Pad both operands to a common length so they can be added pairwise.
        let lhs = self.padded(len);
        let rhs = rhs.padded(len);
        let sums = lhs
            .coeffs
            .iter()
            .zip(rhs.coeffs.iter())
            .map(|(&x, &y)| x + y)
            .collect();
        PolynomialCoeffs::new(sums)
    }
}
/// Sums an iterator of polynomials, starting from the empty polynomial.
impl<F: Field> Sum for PolynomialCoeffs<F> {
    fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
        let mut total = Self::empty();
        for poly in iter {
            total = &total + &poly;
        }
        total
    }
}
/// Coefficient-wise difference; the result has as many coefficients as the longer operand.
impl<F: Field> Sub for &PolynomialCoeffs<F> {
    type Output = PolynomialCoeffs<F>;

    fn sub(self, rhs: Self) -> Self::Output {
        let len = self.len().max(rhs.len());
        let mut diff = self.padded(len).coeffs;
        // `diff` is at least as long as `rhs`, so the zip visits every coefficient of `rhs`.
        diff.iter_mut()
            .zip(rhs.coeffs.iter())
            .for_each(|(slot, &c)| *slot -= c);
        PolynomialCoeffs::new(diff)
    }
}
/// In-place coefficient-wise addition of an owned polynomial.
impl<F: Field> AddAssign for PolynomialCoeffs<F> {
    fn add_assign(&mut self, rhs: Self) {
        // Grow if needed so every coefficient of `rhs` has a slot.
        if rhs.len() > self.len() {
            self.coeffs.resize(rhs.len(), F::ZERO);
        }
        self.coeffs
            .iter_mut()
            .zip(rhs.coeffs)
            .for_each(|(slot, r)| *slot += r);
    }
}
/// In-place coefficient-wise addition of a borrowed polynomial.
impl<F: Field> AddAssign<&Self> for PolynomialCoeffs<F> {
    fn add_assign(&mut self, rhs: &Self) {
        // Grow if needed so every coefficient of `rhs` has a slot.
        if rhs.len() > self.len() {
            self.coeffs.resize(rhs.len(), F::ZERO);
        }
        self.coeffs
            .iter_mut()
            .zip(rhs.coeffs.iter())
            .for_each(|(slot, &r)| *slot += r);
    }
}
/// In-place coefficient-wise subtraction of an owned polynomial.
impl<F: Field> SubAssign for PolynomialCoeffs<F> {
    fn sub_assign(&mut self, rhs: Self) {
        // Grow if needed so every coefficient of `rhs` has a slot.
        if rhs.len() > self.len() {
            self.coeffs.resize(rhs.len(), F::ZERO);
        }
        self.coeffs
            .iter_mut()
            .zip(rhs.coeffs)
            .for_each(|(slot, r)| *slot -= r);
    }
}
/// In-place coefficient-wise subtraction of a borrowed polynomial.
impl<F: Field> SubAssign<&Self> for PolynomialCoeffs<F> {
    fn sub_assign(&mut self, rhs: &Self) {
        // Grow if needed so every coefficient of `rhs` has a slot.
        if rhs.len() > self.len() {
            self.coeffs.resize(rhs.len(), F::ZERO);
        }
        self.coeffs
            .iter_mut()
            .zip(rhs.coeffs.iter())
            .for_each(|(slot, &r)| *slot -= r);
    }
}
/// Multiplies every coefficient by the scalar `rhs`, returning a new polynomial.
impl<F: Field> Mul<F> for &PolynomialCoeffs<F> {
    type Output = PolynomialCoeffs<F>;

    fn mul(self, rhs: F) -> Self::Output {
        let mut scaled = Vec::with_capacity(self.len());
        for &coeff in &self.coeffs {
            scaled.push(rhs * coeff);
        }
        PolynomialCoeffs::new(scaled)
    }
}
/// Multiplies every coefficient by the scalar `rhs` in place.
impl<F: Field> MulAssign<F> for PolynomialCoeffs<F> {
    fn mul_assign(&mut self, rhs: F) {
        for coeff in self.coeffs.iter_mut() {
            *coeff *= rhs;
        }
    }
}
impl<F: Field> Mul for &PolynomialCoeffs<F> {
    type Output = PolynomialCoeffs<F>;

    /// Polynomial product computed through the evaluation domain.
    ///
    /// Both operands are padded to `(self.len() + rhs.len()).next_power_of_two()`
    /// coefficients and transformed with `fft`; the resulting values are multiplied
    /// pairwise and the product is transformed back with `ifft`.
    // Clippy flags the `+` inside a `Mul` impl; here it only sizes the FFT input.
    #[allow(clippy::suspicious_arithmetic_impl)]
    fn mul(self, rhs: Self) -> Self::Output {
        let domain_size = (self.len() + rhs.len()).next_power_of_two();
        let lhs_evals = self.padded(domain_size).fft();
        let rhs_evals = rhs.padded(domain_size).fft();
        let product_evals: Vec<F> = lhs_evals
            .values
            .into_iter()
            .zip(rhs_evals.values)
            .map(|(l, r)| l * r)
            .collect();
        ifft(&product_evals.into())
    }
}
#[cfg(test)]
mod tests {
    use std::time::Instant;

    use rand::{thread_rng, Rng};

    use super::*;
    use crate::field::goldilocks_field::GoldilocksField;

    // `trimmed` must drop trailing zero coefficients and keep everything else.
    #[test]
    fn test_trimmed() {
        type F = GoldilocksField;

        assert_eq!(
            PolynomialCoeffs::<F> { coeffs: vec![] }.trimmed(),
            PolynomialCoeffs::<F> { coeffs: vec![] }
        );
        assert_eq!(
            PolynomialCoeffs::<F> {
                coeffs: vec![F::ZERO]
            }
            .trimmed(),
            PolynomialCoeffs::<F> { coeffs: vec![] }
        );
        assert_eq!(
            PolynomialCoeffs::<F> {
                coeffs: vec![F::ONE, F::TWO, F::ZERO, F::ZERO]
            }
            .trimmed(),
            PolynomialCoeffs::<F> {
                coeffs: vec![F::ONE, F::TWO]
            }
        );
    }

    // `coset_fft` must agree with evaluating the polynomial point by point on the
    // shifted subgroup, and `coset_ifft` must undo it.
    #[test]
    fn test_coset_fft() {
        type F = GoldilocksField;

        let k = 8;
        let n = 1 << k;
        let poly = PolynomialCoeffs::new(F::rand_vec(n));
        let shift = F::rand();
        let coset_evals = poly.coset_fft(shift).values;

        let generator = F::primitive_root_of_unity(k);
        let naive_coset_evals = F::cyclic_subgroup_coset_known_order(generator, shift, n)
            .into_iter()
            .map(|x| poly.eval(x))
            .collect::<Vec<_>>();
        assert_eq!(coset_evals, naive_coset_evals);

        let ifft_coeffs = PolynomialValues::new(coset_evals).coset_ifft(shift);
        assert_eq!(poly, ifft_coeffs.into());
    }

    // `coset_ifft` must produce coefficients whose point-by-point evaluations on the
    // shifted subgroup reproduce the input values, and `coset_fft` must undo it.
    #[test]
    fn test_coset_ifft() {
        type F = GoldilocksField;

        let k = 8;
        let n = 1 << k;
        let evals = PolynomialValues::new(F::rand_vec(n));
        let shift = F::rand();
        let coeffs = evals.coset_ifft(shift);

        let generator = F::primitive_root_of_unity(k);
        let naive_coset_evals = F::cyclic_subgroup_coset_known_order(generator, shift, n)
            .into_iter()
            .map(|x| coeffs.eval(x))
            .collect::<Vec<_>>();
        assert_eq!(evals, naive_coset_evals.into());

        let fft_evals = coeffs.coset_fft(shift);
        assert_eq!(evals, fft_evals);
    }

    // The product polynomial must evaluate to the product of the factors'
    // evaluations at random points.
    #[test]
    fn test_polynomial_multiplication() {
        type F = GoldilocksField;

        let mut rng = thread_rng();
        let (a_deg, b_deg) = (rng.gen_range(1..10_000), rng.gen_range(1..10_000));
        let a = PolynomialCoeffs::new(F::rand_vec(a_deg));
        let b = PolynomialCoeffs::new(F::rand_vec(b_deg));
        // The product used to be computed twice with the identical expression and
        // both copies asserted; a single product gives the same coverage.
        let product = &a * &b;
        for _ in 0..1000 {
            let x = F::rand();
            assert_eq!(product.eval(x), a.eval(x) * b.eval(x));
        }
    }

    // `a * a.inv_mod_xn(n)`, truncated to its first `n` coefficients, must equal 1.
    #[test]
    fn test_inv_mod_xn() {
        type F = GoldilocksField;

        let mut rng = thread_rng();
        let a_deg = rng.gen_range(1..1_000);
        let n = rng.gen_range(1..1_000);
        let a = PolynomialCoeffs::new(F::rand_vec(a_deg));
        let b = a.inv_mod_xn(n);
        let mut m = &a * &b;
        // Reduce modulo X^n by discarding the coefficients from index `n` upwards.
        m.coeffs.drain(n..);
        m.trim();
        assert_eq!(
            m,
            PolynomialCoeffs::new(vec![F::ONE]),
            "a: {:#?}, b:{:#?}, n:{:#?}, m:{:#?}",
            a,
            b,
            n,
            m
        );
    }

    // Long division must satisfy `a = b * q + r` at random evaluation points.
    #[test]
    fn test_polynomial_long_division() {
        type F = GoldilocksField;

        let mut rng = thread_rng();
        let (a_deg, b_deg) = (rng.gen_range(1..10_000), rng.gen_range(1..10_000));
        let a = PolynomialCoeffs::new(F::rand_vec(a_deg));
        let b = PolynomialCoeffs::new(F::rand_vec(b_deg));
        let (q, r) = a.div_rem_long_division(&b);
        for _ in 0..1000 {
            let x = F::rand();
            assert_eq!(a.eval(x), b.eval(x) * q.eval(x) + r.eval(x));
        }
    }

    // `div_rem` must satisfy `a = b * q + r` at random evaluation points.
    #[test]
    fn test_polynomial_division() {
        type F = GoldilocksField;

        let mut rng = thread_rng();
        let (a_deg, b_deg) = (rng.gen_range(1..10_000), rng.gen_range(1..10_000));
        let a = PolynomialCoeffs::new(F::rand_vec(a_deg));
        let b = PolynomialCoeffs::new(F::rand_vec(b_deg));
        let (q, r) = a.div_rem(&b);
        for _ in 0..1000 {
            let x = F::rand();
            assert_eq!(a.eval(x), b.eval(x) * q.eval(x) + r.eval(x));
        }
    }

    // Same identity as above, with a single-coefficient (constant) divisor.
    #[test]
    fn test_polynomial_division_by_constant() {
        type F = GoldilocksField;

        let mut rng = thread_rng();
        let a_deg = rng.gen_range(1..10_000);
        let a = PolynomialCoeffs::new(F::rand_vec(a_deg));
        let b = PolynomialCoeffs::from(vec![F::rand()]);
        let (q, r) = a.div_rem(&b);
        for _ in 0..1000 {
            let x = F::rand();
            assert_eq!(a.eval(x), b.eval(x) * q.eval(x) + r.eval(x));
        }
    }

    // Multiplying by `X^n - 1` and then calling `divide_by_z_h(n)` must give back
    // the original polynomial.
    #[test]
    fn test_division_by_z_h() {
        type F = GoldilocksField;

        let mut rng = thread_rng();
        // `n` is drawn from `1..a_deg` below. With `a_deg == 1` that range is empty
        // and `gen_range` panics, so `a_deg` must be at least 2.
        let a_deg = rng.gen_range(2..10_000);
        let n = rng.gen_range(1..a_deg);
        let mut a = PolynomialCoeffs::new(F::rand_vec(a_deg));
        a.trim();
        let z_h = {
            let mut z_h_vec = vec![F::ZERO; n + 1];
            z_h_vec[n] = F::ONE;
            z_h_vec[0] = F::NEG_ONE;
            PolynomialCoeffs::new(z_h_vec)
        };
        let m = &a * &z_h;
        let now = Instant::now();
        let mut a_test = m.divide_by_z_h(n);
        a_test.trim();
        println!("Division time: {:?}", now.elapsed());
        assert_eq!(a, a_test);
    }

    // Calling `divide_by_z_h` on the empty polynomial must not panic.
    #[test]
    fn divide_zero_poly_by_z_h() {
        let zero_poly = PolynomialCoeffs::<GoldilocksField>::empty();
        zero_poly.divide_by_z_h(16);
    }

    // Test to see which polynomial division method is faster for divisions of the type
    // `(X^n - 1)/(X - a)`.
    #[test]
    fn test_division_linear() {
        type F = GoldilocksField;

        let mut rng = thread_rng();
        let l = 14;
        let n = 1 << l;
        let g = F::primitive_root_of_unity(l);
        let xn_minus_one = {
            let mut xn_min_one_vec = vec![F::ZERO; n + 1];
            xn_min_one_vec[n] = F::ONE;
            xn_min_one_vec[0] = F::NEG_ONE;
            PolynomialCoeffs::new(xn_min_one_vec)
        };

        let a = g.exp_u64(rng.gen_range(0..(n as u64)));
        let denom = PolynomialCoeffs::new(vec![-a, F::ONE]);

        // First timing printed: `div_rem`.
        let now = Instant::now();
        xn_minus_one.div_rem(&denom);
        println!("Division time: {:?}", now.elapsed());

        // Second timing printed: `div_rem_long_division`.
        let now = Instant::now();
        xn_minus_one.div_rem_long_division(&denom);
        println!("Division time: {:?}", now.elapsed());
    }

    // Equality must ignore trailing zero coefficients and nothing else.
    #[test]
    fn eq() {
        type F = GoldilocksField;

        assert_eq!(
            PolynomialCoeffs::<F>::new(vec![]),
            PolynomialCoeffs::new(vec![])
        );
        assert_eq!(
            PolynomialCoeffs::<F>::new(vec![F::ZERO]),
            PolynomialCoeffs::new(vec![F::ZERO])
        );
        assert_eq!(
            PolynomialCoeffs::<F>::new(vec![]),
            PolynomialCoeffs::new(vec![F::ZERO])
        );
        assert_eq!(
            PolynomialCoeffs::<F>::new(vec![F::ZERO]),
            PolynomialCoeffs::new(vec![])
        );
        assert_eq!(
            PolynomialCoeffs::<F>::new(vec![F::ZERO]),
            PolynomialCoeffs::new(vec![F::ZERO, F::ZERO])
        );
        assert_eq!(
            PolynomialCoeffs::<F>::new(vec![F::ONE]),
            PolynomialCoeffs::new(vec![F::ONE, F::ZERO])
        );
        assert_ne!(
            PolynomialCoeffs::<F>::new(vec![]),
            PolynomialCoeffs::new(vec![F::ONE])
        );
        assert_ne!(
            PolynomialCoeffs::<F>::new(vec![F::ZERO]),
            PolynomialCoeffs::new(vec![F::ZERO, F::ONE])
        );
        assert_ne!(
            PolynomialCoeffs::<F>::new(vec![F::ZERO]),
            PolynomialCoeffs::new(vec![F::ONE, F::ZERO])
        );
    }
}

Some files were not shown because too many files have changed in this diff Show More