mirror of
https://github.com/logos-storage/plonky2.git
synced 2026-01-02 13:53:07 +00:00
Merge branch 'main' into generic_configuration
# Conflicts: # src/field/extension_field/mod.rs # src/fri/recursive_verifier.rs # src/gadgets/arithmetic.rs # src/gadgets/arithmetic_extension.rs # src/gadgets/hash.rs # src/gadgets/interpolation.rs # src/gadgets/random_access.rs # src/gadgets/sorting.rs # src/gates/arithmetic_u32.rs # src/gates/gate_tree.rs # src/gates/interpolation.rs # src/gates/poseidon.rs # src/gates/poseidon_mds.rs # src/gates/random_access.rs # src/hash/hashing.rs # src/hash/merkle_proofs.rs # src/hash/poseidon.rs # src/iop/challenger.rs # src/iop/generator.rs # src/iop/witness.rs # src/plonk/circuit_data.rs # src/plonk/proof.rs # src/plonk/prover.rs # src/plonk/recursive_verifier.rs # src/util/partial_products.rs # src/util/reducing.rs
This commit is contained in:
commit
bdbc8b6931
@ -28,9 +28,11 @@ jobs:
|
||||
with:
|
||||
command: test
|
||||
args: --all
|
||||
env:
|
||||
RUSTFLAGS: -Copt-level=3 -Cdebug-assertions -Coverflow-checks=y
|
||||
|
||||
lints:
|
||||
name: Formatting
|
||||
name: Formatting and Clippy
|
||||
runs-on: ubuntu-latest
|
||||
if: "! contains(toJSON(github.event.commits.*.message), '[skip-ci]')"
|
||||
steps:
|
||||
@ -43,10 +45,17 @@ jobs:
|
||||
profile: minimal
|
||||
toolchain: nightly
|
||||
override: true
|
||||
components: rustfmt
|
||||
components: rustfmt, clippy
|
||||
|
||||
- name: Run cargo fmt
|
||||
uses: actions-rs/cargo@v1
|
||||
with:
|
||||
command: fmt
|
||||
args: --all -- --check
|
||||
|
||||
- name: Run cargo clippy
|
||||
uses: actions-rs/cargo@v1
|
||||
with:
|
||||
command: clippy
|
||||
args: --all-features --all-targets -- -D warnings -A incomplete-features
|
||||
|
||||
|
||||
@ -8,7 +8,7 @@ license = "MIT OR Apache-2.0"
|
||||
repository = "https://github.com/mir-protocol/plonky2"
|
||||
keywords = ["cryptography", "SNARK", "FRI"]
|
||||
categories = ["cryptography"]
|
||||
edition = "2018"
|
||||
edition = "2021"
|
||||
default-run = "bench_recursion"
|
||||
|
||||
[dependencies]
|
||||
@ -28,6 +28,9 @@ serde_cbor = "0.11.1"
|
||||
keccak-hash = "0.8.0"
|
||||
static_assertions = "1.1.0"
|
||||
|
||||
[target.'cfg(not(target_env = "msvc"))'.dependencies]
|
||||
jemallocator = "0.3.2"
|
||||
|
||||
[dev-dependencies]
|
||||
criterion = "0.3.5"
|
||||
tynm = "0.1.6"
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion};
|
||||
use plonky2::field::field_types::Field;
|
||||
use plonky2::field::goldilocks_field::GoldilocksField;
|
||||
use plonky2::polynomial::polynomial::PolynomialCoeffs;
|
||||
use plonky2::polynomial::PolynomialCoeffs;
|
||||
use tynm::type_name;
|
||||
|
||||
pub(crate) fn bench_ffts<F: Field>(c: &mut Criterion) {
|
||||
|
||||
@ -1,5 +1,3 @@
|
||||
#![feature(destructuring_assignment)]
|
||||
|
||||
use criterion::{criterion_group, criterion_main, BatchSize, Criterion};
|
||||
use plonky2::field::extension_field::quartic::QuarticExtension;
|
||||
use plonky2::field::field_types::Field;
|
||||
@ -112,6 +110,66 @@ pub(crate) fn bench_field<F: Field>(c: &mut Criterion) {
|
||||
c.bench_function(&format!("try_inverse<{}>", type_name::<F>()), |b| {
|
||||
b.iter_batched(|| F::rand(), |x| x.try_inverse(), BatchSize::SmallInput)
|
||||
});
|
||||
|
||||
c.bench_function(
|
||||
&format!("batch_multiplicative_inverse-tiny<{}>", type_name::<F>()),
|
||||
|b| {
|
||||
b.iter_batched(
|
||||
|| (0..2).into_iter().map(|_| F::rand()).collect::<Vec<_>>(),
|
||||
|x| F::batch_multiplicative_inverse(&x),
|
||||
BatchSize::SmallInput,
|
||||
)
|
||||
},
|
||||
);
|
||||
|
||||
c.bench_function(
|
||||
&format!("batch_multiplicative_inverse-small<{}>", type_name::<F>()),
|
||||
|b| {
|
||||
b.iter_batched(
|
||||
|| (0..4).into_iter().map(|_| F::rand()).collect::<Vec<_>>(),
|
||||
|x| F::batch_multiplicative_inverse(&x),
|
||||
BatchSize::SmallInput,
|
||||
)
|
||||
},
|
||||
);
|
||||
|
||||
c.bench_function(
|
||||
&format!("batch_multiplicative_inverse-medium<{}>", type_name::<F>()),
|
||||
|b| {
|
||||
b.iter_batched(
|
||||
|| (0..16).into_iter().map(|_| F::rand()).collect::<Vec<_>>(),
|
||||
|x| F::batch_multiplicative_inverse(&x),
|
||||
BatchSize::SmallInput,
|
||||
)
|
||||
},
|
||||
);
|
||||
|
||||
c.bench_function(
|
||||
&format!("batch_multiplicative_inverse-large<{}>", type_name::<F>()),
|
||||
|b| {
|
||||
b.iter_batched(
|
||||
|| (0..256).into_iter().map(|_| F::rand()).collect::<Vec<_>>(),
|
||||
|x| F::batch_multiplicative_inverse(&x),
|
||||
BatchSize::LargeInput,
|
||||
)
|
||||
},
|
||||
);
|
||||
|
||||
c.bench_function(
|
||||
&format!("batch_multiplicative_inverse-huge<{}>", type_name::<F>()),
|
||||
|b| {
|
||||
b.iter_batched(
|
||||
|| {
|
||||
(0..65536)
|
||||
.into_iter()
|
||||
.map(|_| F::rand())
|
||||
.collect::<Vec<_>>()
|
||||
},
|
||||
|x| F::batch_multiplicative_inverse(&x),
|
||||
BatchSize::LargeInput,
|
||||
)
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
fn criterion_benchmark(c: &mut Criterion) {
|
||||
|
||||
@ -1,4 +1,3 @@
|
||||
#![feature(destructuring_assignment)]
|
||||
#![feature(generic_const_exprs)]
|
||||
|
||||
use criterion::{criterion_group, criterion_main, BatchSize, Criterion};
|
||||
@ -19,7 +18,7 @@ pub(crate) fn bench_gmimc<F: GMiMC<WIDTH>, const WIDTH: usize>(c: &mut Criterion
|
||||
|
||||
pub(crate) fn bench_poseidon<F: Poseidon<WIDTH>, const WIDTH: usize>(c: &mut Criterion)
|
||||
where
|
||||
[(); WIDTH - 1]: ,
|
||||
[(); WIDTH - 1]:,
|
||||
{
|
||||
c.bench_function(&format!("poseidon<{}, {}>", type_name::<F>(), WIDTH), |b| {
|
||||
b.iter_batched(
|
||||
|
||||
@ -2,7 +2,7 @@ use std::time::Instant;
|
||||
|
||||
use plonky2::field::field_types::Field;
|
||||
use plonky2::field::goldilocks_field::GoldilocksField;
|
||||
use plonky2::polynomial::polynomial::PolynomialValues;
|
||||
use plonky2::polynomial::PolynomialValues;
|
||||
use rayon::prelude::*;
|
||||
|
||||
type F = GoldilocksField;
|
||||
|
||||
@ -24,6 +24,7 @@ fn bench_prove<C: GenericConfig<D>, const D: usize>() -> Result<()> {
|
||||
num_wires: 126,
|
||||
num_routed_wires: 33,
|
||||
constant_gate_size: 6,
|
||||
use_base_arithmetic_gate: false,
|
||||
security_bits: 128,
|
||||
rate_bits: 3,
|
||||
num_challenges: 3,
|
||||
|
||||
@ -1,5 +1,7 @@
|
||||
//! Generates random constants using ChaCha20, seeded with zero.
|
||||
|
||||
#![allow(clippy::needless_range_loop)]
|
||||
|
||||
use plonky2::field::field_types::PrimeField;
|
||||
use plonky2::field::goldilocks_field::GoldilocksField;
|
||||
use rand::{Rng, SeedableRng};
|
||||
|
||||
156
src/curve/curve_adds.rs
Normal file
156
src/curve/curve_adds.rs
Normal file
@ -0,0 +1,156 @@
|
||||
use std::ops::Add;
|
||||
|
||||
use crate::curve::curve_types::{AffinePoint, Curve, ProjectivePoint};
|
||||
use crate::field::field_types::Field;
|
||||
|
||||
impl<C: Curve> Add<ProjectivePoint<C>> for ProjectivePoint<C> {
|
||||
type Output = ProjectivePoint<C>;
|
||||
|
||||
fn add(self, rhs: ProjectivePoint<C>) -> Self::Output {
|
||||
let ProjectivePoint {
|
||||
x: x1,
|
||||
y: y1,
|
||||
z: z1,
|
||||
} = self;
|
||||
let ProjectivePoint {
|
||||
x: x2,
|
||||
y: y2,
|
||||
z: z2,
|
||||
} = rhs;
|
||||
|
||||
if z1 == C::BaseField::ZERO {
|
||||
return rhs;
|
||||
}
|
||||
if z2 == C::BaseField::ZERO {
|
||||
return self;
|
||||
}
|
||||
|
||||
let x1z2 = x1 * z2;
|
||||
let y1z2 = y1 * z2;
|
||||
let x2z1 = x2 * z1;
|
||||
let y2z1 = y2 * z1;
|
||||
|
||||
// Check if we're doubling or adding inverses.
|
||||
if x1z2 == x2z1 {
|
||||
if y1z2 == y2z1 {
|
||||
// TODO: inline to avoid redundant muls.
|
||||
return self.double();
|
||||
}
|
||||
if y1z2 == -y2z1 {
|
||||
return ProjectivePoint::ZERO;
|
||||
}
|
||||
}
|
||||
|
||||
// From https://www.hyperelliptic.org/EFD/g1p/data/shortw/projective/addition/add-1998-cmo-2
|
||||
let z1z2 = z1 * z2;
|
||||
let u = y2z1 - y1z2;
|
||||
let uu = u.square();
|
||||
let v = x2z1 - x1z2;
|
||||
let vv = v.square();
|
||||
let vvv = v * vv;
|
||||
let r = vv * x1z2;
|
||||
let a = uu * z1z2 - vvv - r.double();
|
||||
let x3 = v * a;
|
||||
let y3 = u * (r - a) - vvv * y1z2;
|
||||
let z3 = vvv * z1z2;
|
||||
ProjectivePoint::nonzero(x3, y3, z3)
|
||||
}
|
||||
}
|
||||
|
||||
impl<C: Curve> Add<AffinePoint<C>> for ProjectivePoint<C> {
|
||||
type Output = ProjectivePoint<C>;
|
||||
|
||||
fn add(self, rhs: AffinePoint<C>) -> Self::Output {
|
||||
let ProjectivePoint {
|
||||
x: x1,
|
||||
y: y1,
|
||||
z: z1,
|
||||
} = self;
|
||||
let AffinePoint {
|
||||
x: x2,
|
||||
y: y2,
|
||||
zero: zero2,
|
||||
} = rhs;
|
||||
|
||||
if z1 == C::BaseField::ZERO {
|
||||
return rhs.to_projective();
|
||||
}
|
||||
if zero2 {
|
||||
return self;
|
||||
}
|
||||
|
||||
let x2z1 = x2 * z1;
|
||||
let y2z1 = y2 * z1;
|
||||
|
||||
// Check if we're doubling or adding inverses.
|
||||
if x1 == x2z1 {
|
||||
if y1 == y2z1 {
|
||||
// TODO: inline to avoid redundant muls.
|
||||
return self.double();
|
||||
}
|
||||
if y1 == -y2z1 {
|
||||
return ProjectivePoint::ZERO;
|
||||
}
|
||||
}
|
||||
|
||||
// From https://www.hyperelliptic.org/EFD/g1p/data/shortw/projective/addition/madd-1998-cmo
|
||||
let u = y2z1 - y1;
|
||||
let uu = u.square();
|
||||
let v = x2z1 - x1;
|
||||
let vv = v.square();
|
||||
let vvv = v * vv;
|
||||
let r = vv * x1;
|
||||
let a = uu * z1 - vvv - r.double();
|
||||
let x3 = v * a;
|
||||
let y3 = u * (r - a) - vvv * y1;
|
||||
let z3 = vvv * z1;
|
||||
ProjectivePoint::nonzero(x3, y3, z3)
|
||||
}
|
||||
}
|
||||
|
||||
impl<C: Curve> Add<AffinePoint<C>> for AffinePoint<C> {
|
||||
type Output = ProjectivePoint<C>;
|
||||
|
||||
fn add(self, rhs: AffinePoint<C>) -> Self::Output {
|
||||
let AffinePoint {
|
||||
x: x1,
|
||||
y: y1,
|
||||
zero: zero1,
|
||||
} = self;
|
||||
let AffinePoint {
|
||||
x: x2,
|
||||
y: y2,
|
||||
zero: zero2,
|
||||
} = rhs;
|
||||
|
||||
if zero1 {
|
||||
return rhs.to_projective();
|
||||
}
|
||||
if zero2 {
|
||||
return self.to_projective();
|
||||
}
|
||||
|
||||
// Check if we're doubling or adding inverses.
|
||||
if x1 == x2 {
|
||||
if y1 == y2 {
|
||||
return self.to_projective().double();
|
||||
}
|
||||
if y1 == -y2 {
|
||||
return ProjectivePoint::ZERO;
|
||||
}
|
||||
}
|
||||
|
||||
// From https://www.hyperelliptic.org/EFD/g1p/data/shortw/projective/addition/mmadd-1998-cmo
|
||||
let u = y2 - y1;
|
||||
let uu = u.square();
|
||||
let v = x2 - x1;
|
||||
let vv = v.square();
|
||||
let vvv = v * vv;
|
||||
let r = vv * x1;
|
||||
let a = uu - vvv - r.double();
|
||||
let x3 = v * a;
|
||||
let y3 = u * (r - a) - vvv * y1;
|
||||
let z3 = vvv;
|
||||
ProjectivePoint::nonzero(x3, y3, z3)
|
||||
}
|
||||
}
|
||||
263
src/curve/curve_msm.rs
Normal file
263
src/curve/curve_msm.rs
Normal file
@ -0,0 +1,263 @@
|
||||
use itertools::Itertools;
|
||||
use rayon::prelude::*;
|
||||
|
||||
use crate::curve::curve_summation::affine_multisummation_best;
|
||||
use crate::curve::curve_types::{AffinePoint, Curve, ProjectivePoint};
|
||||
use crate::field::field_types::Field;
|
||||
|
||||
/// In Yao's method, we compute an affine summation for each digit. In a parallel setting, it would
|
||||
/// be easiest to assign individual summations to threads, but this would be sub-optimal because
|
||||
/// multi-summations can be more efficient than repeating individual summations (see
|
||||
/// `affine_multisummation_best`). Thus we divide digits into large chunks, and assign chunks of
|
||||
/// digits to threads. Note that there is a delicate balance here, as large chunks can result in
|
||||
/// uneven distributions of work among threads.
|
||||
const DIGITS_PER_CHUNK: usize = 80;
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct MsmPrecomputation<C: Curve> {
|
||||
/// For each generator (in the order they were passed to `msm_precompute`), contains a vector
|
||||
/// of powers, i.e. [(2^w)^i] for i < DIGITS.
|
||||
// TODO: Use compressed coordinates here.
|
||||
powers_per_generator: Vec<Vec<AffinePoint<C>>>,
|
||||
|
||||
/// The window size.
|
||||
w: usize,
|
||||
}
|
||||
|
||||
pub fn msm_precompute<C: Curve>(
|
||||
generators: &[ProjectivePoint<C>],
|
||||
w: usize,
|
||||
) -> MsmPrecomputation<C> {
|
||||
MsmPrecomputation {
|
||||
powers_per_generator: generators
|
||||
.into_par_iter()
|
||||
.map(|&g| precompute_single_generator(g, w))
|
||||
.collect(),
|
||||
w,
|
||||
}
|
||||
}
|
||||
|
||||
fn precompute_single_generator<C: Curve>(g: ProjectivePoint<C>, w: usize) -> Vec<AffinePoint<C>> {
|
||||
let digits = (C::ScalarField::BITS + w - 1) / w;
|
||||
let mut powers: Vec<ProjectivePoint<C>> = Vec::with_capacity(digits);
|
||||
powers.push(g);
|
||||
for i in 1..digits {
|
||||
let mut power_i_proj = powers[i - 1];
|
||||
for _j in 0..w {
|
||||
power_i_proj = power_i_proj.double();
|
||||
}
|
||||
powers.push(power_i_proj);
|
||||
}
|
||||
ProjectivePoint::batch_to_affine(&powers)
|
||||
}
|
||||
|
||||
pub fn msm_parallel<C: Curve>(
|
||||
scalars: &[C::ScalarField],
|
||||
generators: &[ProjectivePoint<C>],
|
||||
w: usize,
|
||||
) -> ProjectivePoint<C> {
|
||||
let precomputation = msm_precompute(generators, w);
|
||||
msm_execute_parallel(&precomputation, scalars)
|
||||
}
|
||||
|
||||
pub fn msm_execute<C: Curve>(
|
||||
precomputation: &MsmPrecomputation<C>,
|
||||
scalars: &[C::ScalarField],
|
||||
) -> ProjectivePoint<C> {
|
||||
assert_eq!(precomputation.powers_per_generator.len(), scalars.len());
|
||||
let w = precomputation.w;
|
||||
let digits = (C::ScalarField::BITS + w - 1) / w;
|
||||
let base = 1 << w;
|
||||
|
||||
// This is a variant of Yao's method, adapted to the multi-scalar setting. Because we use
|
||||
// extremely large windows, the repeated scans in Yao's method could be more expensive than the
|
||||
// actual group operations. To avoid this, we store a multimap from each possible digit to the
|
||||
// positions in which that digit occurs in the scalars. These positions have the form (i, j),
|
||||
// where i is the index of the generator and j is an index into the digits of the scalar
|
||||
// associated with that generator.
|
||||
let mut digit_occurrences: Vec<Vec<(usize, usize)>> = Vec::with_capacity(digits);
|
||||
for _i in 0..base {
|
||||
digit_occurrences.push(Vec::new());
|
||||
}
|
||||
for (i, scalar) in scalars.iter().enumerate() {
|
||||
let digits = to_digits::<C>(scalar, w);
|
||||
for (j, &digit) in digits.iter().enumerate() {
|
||||
digit_occurrences[digit].push((i, j));
|
||||
}
|
||||
}
|
||||
|
||||
let mut y = ProjectivePoint::ZERO;
|
||||
let mut u = ProjectivePoint::ZERO;
|
||||
|
||||
for digit in (1..base).rev() {
|
||||
for &(i, j) in &digit_occurrences[digit] {
|
||||
u = u + precomputation.powers_per_generator[i][j];
|
||||
}
|
||||
y = y + u;
|
||||
}
|
||||
|
||||
y
|
||||
}
|
||||
|
||||
pub fn msm_execute_parallel<C: Curve>(
|
||||
precomputation: &MsmPrecomputation<C>,
|
||||
scalars: &[C::ScalarField],
|
||||
) -> ProjectivePoint<C> {
|
||||
assert_eq!(precomputation.powers_per_generator.len(), scalars.len());
|
||||
let w = precomputation.w;
|
||||
let digits = (C::ScalarField::BITS + w - 1) / w;
|
||||
let base = 1 << w;
|
||||
|
||||
// This is a variant of Yao's method, adapted to the multi-scalar setting. Because we use
|
||||
// extremely large windows, the repeated scans in Yao's method could be more expensive than the
|
||||
// actual group operations. To avoid this, we store a multimap from each possible digit to the
|
||||
// positions in which that digit occurs in the scalars. These positions have the form (i, j),
|
||||
// where i is the index of the generator and j is an index into the digits of the scalar
|
||||
// associated with that generator.
|
||||
let mut digit_occurrences: Vec<Vec<(usize, usize)>> = Vec::with_capacity(digits);
|
||||
for _i in 0..base {
|
||||
digit_occurrences.push(Vec::new());
|
||||
}
|
||||
for (i, scalar) in scalars.iter().enumerate() {
|
||||
let digits = to_digits::<C>(scalar, w);
|
||||
for (j, &digit) in digits.iter().enumerate() {
|
||||
digit_occurrences[digit].push((i, j));
|
||||
}
|
||||
}
|
||||
|
||||
// For each digit, we add up the powers associated with all occurrences that digit.
|
||||
let digits: Vec<usize> = (0..base).collect();
|
||||
let digit_acc: Vec<ProjectivePoint<C>> = digits
|
||||
.par_chunks(DIGITS_PER_CHUNK)
|
||||
.flat_map(|chunk| {
|
||||
let summations: Vec<Vec<AffinePoint<C>>> = chunk
|
||||
.iter()
|
||||
.map(|&digit| {
|
||||
digit_occurrences[digit]
|
||||
.iter()
|
||||
.map(|&(i, j)| precomputation.powers_per_generator[i][j])
|
||||
.collect()
|
||||
})
|
||||
.collect();
|
||||
affine_multisummation_best(summations)
|
||||
})
|
||||
.collect();
|
||||
// println!("Computing the per-digit summations (in parallel) took {}s", start.elapsed().as_secs_f64());
|
||||
|
||||
let mut y = ProjectivePoint::ZERO;
|
||||
let mut u = ProjectivePoint::ZERO;
|
||||
for digit in (1..base).rev() {
|
||||
u = u + digit_acc[digit];
|
||||
y = y + u;
|
||||
}
|
||||
// println!("Final summation (sequential) {}s", start.elapsed().as_secs_f64());
|
||||
y
|
||||
}
|
||||
|
||||
pub(crate) fn to_digits<C: Curve>(x: &C::ScalarField, w: usize) -> Vec<usize> {
|
||||
let scalar_bits = C::ScalarField::BITS;
|
||||
let num_digits = (scalar_bits + w - 1) / w;
|
||||
|
||||
// Convert x to a bool array.
|
||||
let x_canonical: Vec<_> = x
|
||||
.to_biguint()
|
||||
.to_u64_digits()
|
||||
.iter()
|
||||
.cloned()
|
||||
.pad_using(scalar_bits / 64, |_| 0)
|
||||
.collect();
|
||||
let mut x_bits = Vec::with_capacity(scalar_bits);
|
||||
for i in 0..scalar_bits {
|
||||
x_bits.push((x_canonical[i / 64] >> (i as u64 % 64) & 1) != 0);
|
||||
}
|
||||
|
||||
let mut digits = Vec::with_capacity(num_digits);
|
||||
for i in 0..num_digits {
|
||||
let mut digit = 0;
|
||||
for j in ((i * w)..((i + 1) * w).min(scalar_bits)).rev() {
|
||||
digit <<= 1;
|
||||
digit |= x_bits[j] as usize;
|
||||
}
|
||||
digits.push(digit);
|
||||
}
|
||||
digits
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use num::BigUint;
|
||||
|
||||
use crate::curve::curve_msm::{msm_execute, msm_precompute, to_digits};
|
||||
use crate::curve::curve_types::Curve;
|
||||
use crate::curve::secp256k1::Secp256K1;
|
||||
use crate::field::field_types::Field;
|
||||
use crate::field::secp256k1_scalar::Secp256K1Scalar;
|
||||
|
||||
#[test]
|
||||
fn test_to_digits() {
|
||||
let x_canonical = [
|
||||
0b10101010101010101010101010101010,
|
||||
0b10101010101010101010101010101010,
|
||||
0b11001100110011001100110011001100,
|
||||
0b11001100110011001100110011001100,
|
||||
0b11110000111100001111000011110000,
|
||||
0b11110000111100001111000011110000,
|
||||
0b00001111111111111111111111111111,
|
||||
0b11111111111111111111111111111111,
|
||||
];
|
||||
let x = Secp256K1Scalar::from_biguint(BigUint::from_slice(&x_canonical));
|
||||
assert_eq!(x.to_biguint().to_u32_digits(), x_canonical);
|
||||
assert_eq!(
|
||||
to_digits::<Secp256K1>(&x, 17),
|
||||
vec![
|
||||
0b01010101010101010,
|
||||
0b10101010101010101,
|
||||
0b01010101010101010,
|
||||
0b11001010101010101,
|
||||
0b01100110011001100,
|
||||
0b00110011001100110,
|
||||
0b10011001100110011,
|
||||
0b11110000110011001,
|
||||
0b01111000011110000,
|
||||
0b00111100001111000,
|
||||
0b00011110000111100,
|
||||
0b11111111111111110,
|
||||
0b01111111111111111,
|
||||
0b11111111111111000,
|
||||
0b11111111111111111,
|
||||
0b1,
|
||||
]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_msm() {
|
||||
let w = 5;
|
||||
|
||||
let generator_1 = Secp256K1::GENERATOR_PROJECTIVE;
|
||||
let generator_2 = generator_1 + generator_1;
|
||||
let generator_3 = generator_1 + generator_2;
|
||||
|
||||
let scalar_1 = Secp256K1Scalar::from_biguint(BigUint::from_slice(&[
|
||||
11111111, 22222222, 33333333, 44444444,
|
||||
]));
|
||||
let scalar_2 = Secp256K1Scalar::from_biguint(BigUint::from_slice(&[
|
||||
22222222, 22222222, 33333333, 44444444,
|
||||
]));
|
||||
let scalar_3 = Secp256K1Scalar::from_biguint(BigUint::from_slice(&[
|
||||
33333333, 22222222, 33333333, 44444444,
|
||||
]));
|
||||
|
||||
let generators = vec![generator_1, generator_2, generator_3];
|
||||
let scalars = vec![scalar_1, scalar_2, scalar_3];
|
||||
|
||||
let precomputation = msm_precompute(&generators, w);
|
||||
let result_msm = msm_execute(&precomputation, &scalars);
|
||||
|
||||
let result_naive = Secp256K1::convert(scalar_1) * generator_1
|
||||
+ Secp256K1::convert(scalar_2) * generator_2
|
||||
+ Secp256K1::convert(scalar_3) * generator_3;
|
||||
|
||||
assert_eq!(result_msm, result_naive);
|
||||
}
|
||||
}
|
||||
97
src/curve/curve_multiplication.rs
Normal file
97
src/curve/curve_multiplication.rs
Normal file
@ -0,0 +1,97 @@
|
||||
use std::ops::Mul;
|
||||
|
||||
use crate::curve::curve_types::{Curve, CurveScalar, ProjectivePoint};
|
||||
use crate::field::field_types::Field;
|
||||
|
||||
const WINDOW_BITS: usize = 4;
|
||||
const BASE: usize = 1 << WINDOW_BITS;
|
||||
|
||||
fn digits_per_scalar<C: Curve>() -> usize {
|
||||
(C::ScalarField::BITS + WINDOW_BITS - 1) / WINDOW_BITS
|
||||
}
|
||||
|
||||
/// Precomputed state used for scalar x ProjectivePoint multiplications,
|
||||
/// specific to a particular generator.
|
||||
#[derive(Clone)]
|
||||
pub struct MultiplicationPrecomputation<C: Curve> {
|
||||
/// [(2^w)^i] g for each i < digits_per_scalar.
|
||||
powers: Vec<ProjectivePoint<C>>,
|
||||
}
|
||||
|
||||
impl<C: Curve> ProjectivePoint<C> {
|
||||
pub fn mul_precompute(&self) -> MultiplicationPrecomputation<C> {
|
||||
let num_digits = digits_per_scalar::<C>();
|
||||
let mut powers = Vec::with_capacity(num_digits);
|
||||
powers.push(*self);
|
||||
for i in 1..num_digits {
|
||||
let mut power_i = powers[i - 1];
|
||||
for _j in 0..WINDOW_BITS {
|
||||
power_i = power_i.double();
|
||||
}
|
||||
powers.push(power_i);
|
||||
}
|
||||
|
||||
MultiplicationPrecomputation { powers }
|
||||
}
|
||||
|
||||
pub fn mul_with_precomputation(
|
||||
&self,
|
||||
scalar: C::ScalarField,
|
||||
precomputation: MultiplicationPrecomputation<C>,
|
||||
) -> Self {
|
||||
// Yao's method; see https://koclab.cs.ucsb.edu/teaching/ecc/eccPapers/Doche-ch09.pdf
|
||||
let precomputed_powers = precomputation.powers;
|
||||
|
||||
let digits = to_digits::<C>(&scalar);
|
||||
|
||||
let mut y = ProjectivePoint::ZERO;
|
||||
let mut u = ProjectivePoint::ZERO;
|
||||
let mut all_summands = Vec::new();
|
||||
for j in (1..BASE).rev() {
|
||||
let mut u_summands = Vec::new();
|
||||
for (i, &digit) in digits.iter().enumerate() {
|
||||
if digit == j as u64 {
|
||||
u_summands.push(precomputed_powers[i]);
|
||||
}
|
||||
}
|
||||
all_summands.push(u_summands);
|
||||
}
|
||||
|
||||
let all_sums: Vec<ProjectivePoint<C>> = all_summands
|
||||
.iter()
|
||||
.cloned()
|
||||
.map(|vec| vec.iter().fold(ProjectivePoint::ZERO, |a, &b| a + b))
|
||||
.collect();
|
||||
for i in 0..all_sums.len() {
|
||||
u = u + all_sums[i];
|
||||
y = y + u;
|
||||
}
|
||||
y
|
||||
}
|
||||
}
|
||||
|
||||
impl<C: Curve> Mul<ProjectivePoint<C>> for CurveScalar<C> {
|
||||
type Output = ProjectivePoint<C>;
|
||||
|
||||
fn mul(self, rhs: ProjectivePoint<C>) -> Self::Output {
|
||||
let precomputation = rhs.mul_precompute();
|
||||
rhs.mul_with_precomputation(self.0, precomputation)
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::assertions_on_constants)]
|
||||
fn to_digits<C: Curve>(x: &C::ScalarField) -> Vec<u64> {
|
||||
debug_assert!(
|
||||
64 % WINDOW_BITS == 0,
|
||||
"For simplicity, only power-of-two window sizes are handled for now"
|
||||
);
|
||||
let digits_per_u64 = 64 / WINDOW_BITS;
|
||||
let mut digits = Vec::with_capacity(digits_per_scalar::<C>());
|
||||
for limb in x.to_biguint().to_u64_digits() {
|
||||
for j in 0..digits_per_u64 {
|
||||
digits.push((limb >> (j * WINDOW_BITS) as u64) % BASE as u64);
|
||||
}
|
||||
}
|
||||
|
||||
digits
|
||||
}
|
||||
237
src/curve/curve_summation.rs
Normal file
237
src/curve/curve_summation.rs
Normal file
@ -0,0 +1,237 @@
|
||||
use std::iter::Sum;
|
||||
|
||||
use crate::curve::curve_types::{AffinePoint, Curve, ProjectivePoint};
|
||||
use crate::field::field_types::Field;
|
||||
|
||||
impl<C: Curve> Sum<AffinePoint<C>> for ProjectivePoint<C> {
|
||||
fn sum<I: Iterator<Item = AffinePoint<C>>>(iter: I) -> ProjectivePoint<C> {
|
||||
let points: Vec<_> = iter.collect();
|
||||
affine_summation_best(points)
|
||||
}
|
||||
}
|
||||
|
||||
impl<C: Curve> Sum for ProjectivePoint<C> {
|
||||
fn sum<I: Iterator<Item = ProjectivePoint<C>>>(iter: I) -> ProjectivePoint<C> {
|
||||
iter.fold(ProjectivePoint::ZERO, |acc, x| acc + x)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn affine_summation_best<C: Curve>(summation: Vec<AffinePoint<C>>) -> ProjectivePoint<C> {
|
||||
let result = affine_multisummation_best(vec![summation]);
|
||||
debug_assert_eq!(result.len(), 1);
|
||||
result[0]
|
||||
}
|
||||
|
||||
pub fn affine_multisummation_best<C: Curve>(
|
||||
summations: Vec<Vec<AffinePoint<C>>>,
|
||||
) -> Vec<ProjectivePoint<C>> {
|
||||
let pairwise_sums: usize = summations.iter().map(|summation| summation.len() / 2).sum();
|
||||
|
||||
// This threshold is chosen based on data from the summation benchmarks.
|
||||
if pairwise_sums < 70 {
|
||||
affine_multisummation_pairwise(summations)
|
||||
} else {
|
||||
affine_multisummation_batch_inversion(summations)
|
||||
}
|
||||
}
|
||||
|
||||
/// Adds each pair of points using an affine + affine = projective formula, then adds up the
|
||||
/// intermediate sums using a projective formula.
|
||||
pub fn affine_multisummation_pairwise<C: Curve>(
|
||||
summations: Vec<Vec<AffinePoint<C>>>,
|
||||
) -> Vec<ProjectivePoint<C>> {
|
||||
summations
|
||||
.into_iter()
|
||||
.map(affine_summation_pairwise)
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Adds each pair of points using an affine + affine = projective formula, then adds up the
|
||||
/// intermediate sums using a projective formula.
|
||||
pub fn affine_summation_pairwise<C: Curve>(points: Vec<AffinePoint<C>>) -> ProjectivePoint<C> {
|
||||
let mut reduced_points: Vec<ProjectivePoint<C>> = Vec::new();
|
||||
for chunk in points.chunks(2) {
|
||||
match chunk.len() {
|
||||
1 => reduced_points.push(chunk[0].to_projective()),
|
||||
2 => reduced_points.push(chunk[0] + chunk[1]),
|
||||
_ => panic!(),
|
||||
}
|
||||
}
|
||||
// TODO: Avoid copying (deref)
|
||||
reduced_points
|
||||
.iter()
|
||||
.fold(ProjectivePoint::ZERO, |sum, x| sum + *x)
|
||||
}
|
||||
|
||||
/// Computes several summations of affine points by applying an affine group law, except that the
|
||||
/// divisions are batched via Montgomery's trick.
|
||||
pub fn affine_summation_batch_inversion<C: Curve>(
|
||||
summation: Vec<AffinePoint<C>>,
|
||||
) -> ProjectivePoint<C> {
|
||||
let result = affine_multisummation_batch_inversion(vec![summation]);
|
||||
debug_assert_eq!(result.len(), 1);
|
||||
result[0]
|
||||
}
|
||||
|
||||
/// Computes several summations of affine points by applying an affine group law, except that the
|
||||
/// divisions are batched via Montgomery's trick.
|
||||
pub fn affine_multisummation_batch_inversion<C: Curve>(
|
||||
summations: Vec<Vec<AffinePoint<C>>>,
|
||||
) -> Vec<ProjectivePoint<C>> {
|
||||
let mut elements_to_invert = Vec::new();
|
||||
|
||||
// For each pair of points, (x1, y1) and (x2, y2), that we're going to add later, we want to
|
||||
// invert either y (if the points are equal) or x1 - x2 (otherwise). We will use these later.
|
||||
for summation in &summations {
|
||||
let n = summation.len();
|
||||
// The special case for n=0 is to avoid underflow.
|
||||
let range_end = if n == 0 { 0 } else { n - 1 };
|
||||
|
||||
for i in (0..range_end).step_by(2) {
|
||||
let p1 = summation[i];
|
||||
let p2 = summation[i + 1];
|
||||
let AffinePoint {
|
||||
x: x1,
|
||||
y: y1,
|
||||
zero: zero1,
|
||||
} = p1;
|
||||
let AffinePoint {
|
||||
x: x2,
|
||||
y: _y2,
|
||||
zero: zero2,
|
||||
} = p2;
|
||||
|
||||
if zero1 || zero2 || p1 == -p2 {
|
||||
// These are trivial cases where we won't need any inverse.
|
||||
} else if p1 == p2 {
|
||||
elements_to_invert.push(y1.double());
|
||||
} else {
|
||||
elements_to_invert.push(x1 - x2);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let inverses: Vec<C::BaseField> =
|
||||
C::BaseField::batch_multiplicative_inverse(&elements_to_invert);
|
||||
|
||||
let mut all_reduced_points = Vec::with_capacity(summations.len());
|
||||
let mut inverse_index = 0;
|
||||
for summation in summations {
|
||||
let n = summation.len();
|
||||
let mut reduced_points = Vec::with_capacity((n + 1) / 2);
|
||||
|
||||
// The special case for n=0 is to avoid underflow.
|
||||
let range_end = if n == 0 { 0 } else { n - 1 };
|
||||
|
||||
for i in (0..range_end).step_by(2) {
|
||||
let p1 = summation[i];
|
||||
let p2 = summation[i + 1];
|
||||
let AffinePoint {
|
||||
x: x1,
|
||||
y: y1,
|
||||
zero: zero1,
|
||||
} = p1;
|
||||
let AffinePoint {
|
||||
x: x2,
|
||||
y: y2,
|
||||
zero: zero2,
|
||||
} = p2;
|
||||
|
||||
let sum = if zero1 {
|
||||
p2
|
||||
} else if zero2 {
|
||||
p1
|
||||
} else if p1 == -p2 {
|
||||
AffinePoint::ZERO
|
||||
} else {
|
||||
// It's a non-trivial case where we need one of the inverses we computed earlier.
|
||||
let inverse = inverses[inverse_index];
|
||||
inverse_index += 1;
|
||||
|
||||
if p1 == p2 {
|
||||
// This is the doubling case.
|
||||
let mut numerator = x1.square().triple();
|
||||
if C::A.is_nonzero() {
|
||||
numerator += C::A;
|
||||
}
|
||||
let quotient = numerator * inverse;
|
||||
let x3 = quotient.square() - x1.double();
|
||||
let y3 = quotient * (x1 - x3) - y1;
|
||||
AffinePoint::nonzero(x3, y3)
|
||||
} else {
|
||||
// This is the general case. We use the incomplete addition formulas 4.3 and 4.4.
|
||||
let quotient = (y1 - y2) * inverse;
|
||||
let x3 = quotient.square() - x1 - x2;
|
||||
let y3 = quotient * (x1 - x3) - y1;
|
||||
AffinePoint::nonzero(x3, y3)
|
||||
}
|
||||
};
|
||||
reduced_points.push(sum);
|
||||
}
|
||||
|
||||
// If n is odd, the last point was not part of a pair.
|
||||
if n % 2 == 1 {
|
||||
reduced_points.push(summation[n - 1]);
|
||||
}
|
||||
|
||||
all_reduced_points.push(reduced_points);
|
||||
}
|
||||
|
||||
// We should have consumed all of the inverses from the batch computation.
|
||||
debug_assert_eq!(inverse_index, inverses.len());
|
||||
|
||||
// Recurse with our smaller set of points.
|
||||
affine_multisummation_best(all_reduced_points)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::curve::curve_summation::{
|
||||
affine_summation_batch_inversion, affine_summation_pairwise,
|
||||
};
|
||||
use crate::curve::curve_types::{Curve, ProjectivePoint};
|
||||
use crate::curve::secp256k1::Secp256K1;
|
||||
|
||||
#[test]
|
||||
fn test_pairwise_affine_summation() {
|
||||
let g_affine = Secp256K1::GENERATOR_AFFINE;
|
||||
let g2_affine = (g_affine + g_affine).to_affine();
|
||||
let g3_affine = (g_affine + g_affine + g_affine).to_affine();
|
||||
let g2_proj = g2_affine.to_projective();
|
||||
let g3_proj = g3_affine.to_projective();
|
||||
assert_eq!(
|
||||
affine_summation_pairwise::<Secp256K1>(vec![g_affine, g_affine]),
|
||||
g2_proj
|
||||
);
|
||||
assert_eq!(
|
||||
affine_summation_pairwise::<Secp256K1>(vec![g_affine, g2_affine]),
|
||||
g3_proj
|
||||
);
|
||||
assert_eq!(
|
||||
affine_summation_pairwise::<Secp256K1>(vec![g_affine, g_affine, g_affine]),
|
||||
g3_proj
|
||||
);
|
||||
assert_eq!(
|
||||
affine_summation_pairwise::<Secp256K1>(vec![]),
|
||||
ProjectivePoint::ZERO
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_pairwise_affine_summation_batch_inversion() {
|
||||
let g = Secp256K1::GENERATOR_AFFINE;
|
||||
let g_proj = g.to_projective();
|
||||
assert_eq!(
|
||||
affine_summation_batch_inversion::<Secp256K1>(vec![g, g]),
|
||||
g_proj + g_proj
|
||||
);
|
||||
assert_eq!(
|
||||
affine_summation_batch_inversion::<Secp256K1>(vec![g, g, g]),
|
||||
g_proj + g_proj + g_proj
|
||||
);
|
||||
assert_eq!(
|
||||
affine_summation_batch_inversion::<Secp256K1>(vec![]),
|
||||
ProjectivePoint::ZERO
|
||||
);
|
||||
}
|
||||
}
|
||||
260
src/curve/curve_types.rs
Normal file
260
src/curve/curve_types.rs
Normal file
@ -0,0 +1,260 @@
|
||||
use std::fmt::Debug;
|
||||
use std::ops::Neg;
|
||||
|
||||
use crate::field::field_types::Field;
|
||||
|
||||
// To avoid implementation conflicts from associated types,
|
||||
// see https://github.com/rust-lang/rust/issues/20400
|
||||
pub struct CurveScalar<C: Curve>(pub <C as Curve>::ScalarField);
|
||||
|
||||
/// A short Weierstrass curve.
|
||||
pub trait Curve: 'static + Sync + Sized + Copy + Debug {
|
||||
type BaseField: Field;
|
||||
type ScalarField: Field;
|
||||
|
||||
const A: Self::BaseField;
|
||||
const B: Self::BaseField;
|
||||
|
||||
const GENERATOR_AFFINE: AffinePoint<Self>;
|
||||
|
||||
const GENERATOR_PROJECTIVE: ProjectivePoint<Self> = ProjectivePoint {
|
||||
x: Self::GENERATOR_AFFINE.x,
|
||||
y: Self::GENERATOR_AFFINE.y,
|
||||
z: Self::BaseField::ONE,
|
||||
};
|
||||
|
||||
fn convert(x: Self::ScalarField) -> CurveScalar<Self> {
|
||||
CurveScalar(x)
|
||||
}
|
||||
|
||||
fn is_safe_curve() -> bool {
|
||||
// Added additional check to prevent using vulnerabilties in case a discriminant is equal to 0.
|
||||
(Self::A.cube().double().double() + Self::B.square().triple().triple().triple())
|
||||
.is_nonzero()
|
||||
}
|
||||
}
|
||||
|
||||
/// A point on a short Weierstrass curve, represented in affine coordinates.
|
||||
#[derive(Copy, Clone, Debug)]
|
||||
pub struct AffinePoint<C: Curve> {
|
||||
pub x: C::BaseField,
|
||||
pub y: C::BaseField,
|
||||
pub zero: bool,
|
||||
}
|
||||
|
||||
impl<C: Curve> AffinePoint<C> {
|
||||
pub const ZERO: Self = Self {
|
||||
x: C::BaseField::ZERO,
|
||||
y: C::BaseField::ZERO,
|
||||
zero: true,
|
||||
};
|
||||
|
||||
pub fn nonzero(x: C::BaseField, y: C::BaseField) -> Self {
|
||||
let point = Self { x, y, zero: false };
|
||||
debug_assert!(point.is_valid());
|
||||
point
|
||||
}
|
||||
|
||||
pub fn is_valid(&self) -> bool {
|
||||
let Self { x, y, zero } = *self;
|
||||
zero || y.square() == x.cube() + C::A * x + C::B
|
||||
}
|
||||
|
||||
pub fn to_projective(&self) -> ProjectivePoint<C> {
|
||||
let Self { x, y, zero } = *self;
|
||||
let z = if zero {
|
||||
C::BaseField::ZERO
|
||||
} else {
|
||||
C::BaseField::ONE
|
||||
};
|
||||
|
||||
ProjectivePoint { x, y, z }
|
||||
}
|
||||
|
||||
pub fn batch_to_projective(affine_points: &[Self]) -> Vec<ProjectivePoint<C>> {
|
||||
affine_points.iter().map(Self::to_projective).collect()
|
||||
}
|
||||
|
||||
pub fn double(&self) -> Self {
|
||||
let AffinePoint { x: x1, y: y1, zero } = *self;
|
||||
|
||||
if zero {
|
||||
return AffinePoint::ZERO;
|
||||
}
|
||||
|
||||
let double_y = y1.double();
|
||||
let inv_double_y = double_y.inverse(); // (2y)^(-1)
|
||||
let triple_xx = x1.square().triple(); // 3x^2
|
||||
let lambda = (triple_xx + C::A) * inv_double_y;
|
||||
let x3 = lambda.square() - self.x.double();
|
||||
let y3 = lambda * (x1 - x3) - y1;
|
||||
|
||||
Self {
|
||||
x: x3,
|
||||
y: y3,
|
||||
zero: false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<C: Curve> PartialEq for AffinePoint<C> {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
let AffinePoint {
|
||||
x: x1,
|
||||
y: y1,
|
||||
zero: zero1,
|
||||
} = *self;
|
||||
let AffinePoint {
|
||||
x: x2,
|
||||
y: y2,
|
||||
zero: zero2,
|
||||
} = *other;
|
||||
if zero1 || zero2 {
|
||||
return zero1 == zero2;
|
||||
}
|
||||
x1 == x2 && y1 == y2
|
||||
}
|
||||
}
|
||||
|
||||
impl<C: Curve> Eq for AffinePoint<C> {}
|
||||
|
||||
/// A point on a short Weierstrass curve, represented in projective coordinates.
|
||||
#[derive(Copy, Clone, Debug)]
|
||||
pub struct ProjectivePoint<C: Curve> {
|
||||
pub x: C::BaseField,
|
||||
pub y: C::BaseField,
|
||||
pub z: C::BaseField,
|
||||
}
|
||||
|
||||
impl<C: Curve> ProjectivePoint<C> {
|
||||
pub const ZERO: Self = Self {
|
||||
x: C::BaseField::ZERO,
|
||||
y: C::BaseField::ONE,
|
||||
z: C::BaseField::ZERO,
|
||||
};
|
||||
|
||||
pub fn nonzero(x: C::BaseField, y: C::BaseField, z: C::BaseField) -> Self {
|
||||
let point = Self { x, y, z };
|
||||
debug_assert!(point.is_valid());
|
||||
point
|
||||
}
|
||||
|
||||
pub fn is_valid(&self) -> bool {
|
||||
let Self { x, y, z } = *self;
|
||||
z.is_zero() || y.square() * z == x.cube() + C::A * x * z.square() + C::B * z.cube()
|
||||
}
|
||||
|
||||
pub fn to_affine(&self) -> AffinePoint<C> {
|
||||
let Self { x, y, z } = *self;
|
||||
if z == C::BaseField::ZERO {
|
||||
AffinePoint::ZERO
|
||||
} else {
|
||||
let z_inv = z.inverse();
|
||||
AffinePoint::nonzero(x * z_inv, y * z_inv)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn batch_to_affine(proj_points: &[Self]) -> Vec<AffinePoint<C>> {
|
||||
let n = proj_points.len();
|
||||
let zs: Vec<C::BaseField> = proj_points.iter().map(|pp| pp.z).collect();
|
||||
let z_invs = C::BaseField::batch_multiplicative_inverse(&zs);
|
||||
|
||||
let mut result = Vec::with_capacity(n);
|
||||
for i in 0..n {
|
||||
let Self { x, y, z } = proj_points[i];
|
||||
result.push(if z == C::BaseField::ZERO {
|
||||
AffinePoint::ZERO
|
||||
} else {
|
||||
let z_inv = z_invs[i];
|
||||
AffinePoint::nonzero(x * z_inv, y * z_inv)
|
||||
});
|
||||
}
|
||||
result
|
||||
}
|
||||
|
||||
// From https://www.hyperelliptic.org/EFD/g1p/data/shortw/projective/doubling/dbl-2007-bl
|
||||
pub fn double(&self) -> Self {
|
||||
let Self { x, y, z } = *self;
|
||||
if z == C::BaseField::ZERO {
|
||||
return ProjectivePoint::ZERO;
|
||||
}
|
||||
|
||||
let xx = x.square();
|
||||
let zz = z.square();
|
||||
let mut w = xx.triple();
|
||||
if C::A.is_nonzero() {
|
||||
w += C::A * zz;
|
||||
}
|
||||
let s = y.double() * z;
|
||||
let r = y * s;
|
||||
let rr = r.square();
|
||||
let b = (x + r).square() - (xx + rr);
|
||||
let h = w.square() - b.double();
|
||||
let x3 = h * s;
|
||||
let y3 = w * (b - h) - rr.double();
|
||||
let z3 = s.cube();
|
||||
Self {
|
||||
x: x3,
|
||||
y: y3,
|
||||
z: z3,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn add_slices(a: &[Self], b: &[Self]) -> Vec<Self> {
|
||||
assert_eq!(a.len(), b.len());
|
||||
a.iter()
|
||||
.zip(b.iter())
|
||||
.map(|(&a_i, &b_i)| a_i + b_i)
|
||||
.collect()
|
||||
}
|
||||
|
||||
pub fn neg(&self) -> Self {
|
||||
Self {
|
||||
x: self.x,
|
||||
y: -self.y,
|
||||
z: self.z,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<C: Curve> PartialEq for ProjectivePoint<C> {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
let ProjectivePoint {
|
||||
x: x1,
|
||||
y: y1,
|
||||
z: z1,
|
||||
} = *self;
|
||||
let ProjectivePoint {
|
||||
x: x2,
|
||||
y: y2,
|
||||
z: z2,
|
||||
} = *other;
|
||||
if z1 == C::BaseField::ZERO || z2 == C::BaseField::ZERO {
|
||||
return z1 == z2;
|
||||
}
|
||||
|
||||
// We want to compare (x1/z1, y1/z1) == (x2/z2, y2/z2).
|
||||
// But to avoid field division, it is better to compare (x1*z2, y1*z2) == (x2*z1, y2*z1).
|
||||
x1 * z2 == x2 * z1 && y1 * z2 == y2 * z1
|
||||
}
|
||||
}
|
||||
|
||||
impl<C: Curve> Eq for ProjectivePoint<C> {}
|
||||
|
||||
impl<C: Curve> Neg for AffinePoint<C> {
|
||||
type Output = AffinePoint<C>;
|
||||
|
||||
fn neg(self) -> Self::Output {
|
||||
let AffinePoint { x, y, zero } = self;
|
||||
AffinePoint { x, y: -y, zero }
|
||||
}
|
||||
}
|
||||
|
||||
impl<C: Curve> Neg for ProjectivePoint<C> {
|
||||
type Output = ProjectivePoint<C>;
|
||||
|
||||
fn neg(self) -> Self::Output {
|
||||
let ProjectivePoint { x, y, z } = self;
|
||||
ProjectivePoint { x, y: -y, z }
|
||||
}
|
||||
}
|
||||
6
src/curve/mod.rs
Normal file
6
src/curve/mod.rs
Normal file
@ -0,0 +1,6 @@
|
||||
pub mod curve_adds;
|
||||
pub mod curve_msm;
|
||||
pub mod curve_multiplication;
|
||||
pub mod curve_summation;
|
||||
pub mod curve_types;
|
||||
pub mod secp256k1;
|
||||
98
src/curve/secp256k1.rs
Normal file
98
src/curve/secp256k1.rs
Normal file
@ -0,0 +1,98 @@
|
||||
use crate::curve::curve_types::{AffinePoint, Curve};
|
||||
use crate::field::field_types::Field;
|
||||
use crate::field::secp256k1_base::Secp256K1Base;
|
||||
use crate::field::secp256k1_scalar::Secp256K1Scalar;
|
||||
|
||||
#[derive(Debug, Copy, Clone)]
|
||||
pub struct Secp256K1;
|
||||
|
||||
impl Curve for Secp256K1 {
|
||||
type BaseField = Secp256K1Base;
|
||||
type ScalarField = Secp256K1Scalar;
|
||||
|
||||
const A: Secp256K1Base = Secp256K1Base::ZERO;
|
||||
const B: Secp256K1Base = Secp256K1Base([7, 0, 0, 0]);
|
||||
const GENERATOR_AFFINE: AffinePoint<Self> = AffinePoint {
|
||||
x: SECP256K1_GENERATOR_X,
|
||||
y: SECP256K1_GENERATOR_Y,
|
||||
zero: false,
|
||||
};
|
||||
}
|
||||
|
||||
// 55066263022277343669578718895168534326250603453777594175500187360389116729240
|
||||
const SECP256K1_GENERATOR_X: Secp256K1Base = Secp256K1Base([
|
||||
0x59F2815B16F81798,
|
||||
0x029BFCDB2DCE28D9,
|
||||
0x55A06295CE870B07,
|
||||
0x79BE667EF9DCBBAC,
|
||||
]);
|
||||
|
||||
/// 32670510020758816978083085130507043184471273380659243275938904335757337482424
|
||||
const SECP256K1_GENERATOR_Y: Secp256K1Base = Secp256K1Base([
|
||||
0x9C47D08FFB10D4B8,
|
||||
0xFD17B448A6855419,
|
||||
0x5DA4FBFC0E1108A8,
|
||||
0x483ADA7726A3C465,
|
||||
]);
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use num::BigUint;
|
||||
|
||||
use crate::curve::curve_types::{AffinePoint, Curve, ProjectivePoint};
|
||||
use crate::curve::secp256k1::Secp256K1;
|
||||
use crate::field::field_types::Field;
|
||||
use crate::field::secp256k1_scalar::Secp256K1Scalar;
|
||||
|
||||
#[test]
|
||||
fn test_generator() {
|
||||
let g = Secp256K1::GENERATOR_AFFINE;
|
||||
assert!(g.is_valid());
|
||||
|
||||
let neg_g = AffinePoint::<Secp256K1> {
|
||||
x: g.x,
|
||||
y: -g.y,
|
||||
zero: g.zero,
|
||||
};
|
||||
assert!(neg_g.is_valid());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_naive_multiplication() {
|
||||
let g = Secp256K1::GENERATOR_PROJECTIVE;
|
||||
let ten = Secp256K1Scalar::from_canonical_u64(10);
|
||||
let product = mul_naive(ten, g);
|
||||
let sum = g + g + g + g + g + g + g + g + g + g;
|
||||
assert_eq!(product, sum);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_g1_multiplication() {
|
||||
let lhs = Secp256K1Scalar::from_biguint(BigUint::from_slice(&[
|
||||
1111, 2222, 3333, 4444, 5555, 6666, 7777, 8888,
|
||||
]));
|
||||
assert_eq!(
|
||||
Secp256K1::convert(lhs) * Secp256K1::GENERATOR_PROJECTIVE,
|
||||
mul_naive(lhs, Secp256K1::GENERATOR_PROJECTIVE)
|
||||
);
|
||||
}
|
||||
|
||||
/// A simple, somewhat inefficient implementation of multiplication which is used as a reference
|
||||
/// for correctness.
|
||||
fn mul_naive(
|
||||
lhs: Secp256K1Scalar,
|
||||
rhs: ProjectivePoint<Secp256K1>,
|
||||
) -> ProjectivePoint<Secp256K1> {
|
||||
let mut g = rhs;
|
||||
let mut sum = ProjectivePoint::ZERO;
|
||||
for limb in lhs.to_biguint().to_u64_digits().iter() {
|
||||
for j in 0..64 {
|
||||
if (limb >> j & 1u64) != 0u64 {
|
||||
sum = sum + g;
|
||||
}
|
||||
g = g.double();
|
||||
}
|
||||
}
|
||||
sum
|
||||
}
|
||||
}
|
||||
@ -31,8 +31,6 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn distinct_cosets() {
|
||||
// TODO: Switch to a smaller test field so that collision rejection is likely to occur.
|
||||
|
||||
type F = GoldilocksField;
|
||||
const SUBGROUP_BITS: usize = 5;
|
||||
const NUM_SHIFTS: usize = 50;
|
||||
|
||||
@ -160,12 +160,32 @@ impl<F: OEF<D>, const D: usize> PolynomialCoeffsAlgebra<F, D> {
|
||||
.fold(ExtensionAlgebra::ZERO, |acc, &c| acc * x + c)
|
||||
}
|
||||
|
||||
/// Evaluate the polynomial at a point given its powers. The first power is the point itself, not 1.
|
||||
pub fn eval_with_powers(&self, powers: &[ExtensionAlgebra<F, D>]) -> ExtensionAlgebra<F, D> {
|
||||
debug_assert_eq!(self.coeffs.len(), powers.len() + 1);
|
||||
let acc = self.coeffs[0];
|
||||
self.coeffs[1..]
|
||||
.iter()
|
||||
.zip(powers)
|
||||
.fold(acc, |acc, (&x, &c)| acc + c * x)
|
||||
}
|
||||
|
||||
pub fn eval_base(&self, x: F) -> ExtensionAlgebra<F, D> {
|
||||
self.coeffs
|
||||
.iter()
|
||||
.rev()
|
||||
.fold(ExtensionAlgebra::ZERO, |acc, &c| acc.scalar_mul(x) + c)
|
||||
}
|
||||
|
||||
/// Evaluate the polynomial at a point given its powers. The first power is the point itself, not 1.
|
||||
pub fn eval_base_with_powers(&self, powers: &[F]) -> ExtensionAlgebra<F, D> {
|
||||
debug_assert_eq!(self.coeffs.len(), powers.len() + 1);
|
||||
let acc = self.coeffs[0];
|
||||
self.coeffs[1..]
|
||||
.iter()
|
||||
.zip(powers)
|
||||
.fold(acc, |acc, (&x, &c)| acc + x.scalar_mul(c))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
use crate::field::field_types::{Field, PrimeField};
|
||||
use std::convert::TryInto;
|
||||
|
||||
use crate::field::field_types::{Field, RichField};
|
||||
|
||||
@ -3,6 +3,7 @@ use std::iter::{Product, Sum};
|
||||
use std::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssign};
|
||||
|
||||
use num::bigint::BigUint;
|
||||
use num::Integer;
|
||||
use rand::Rng;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
@ -49,26 +50,28 @@ impl<F: Extendable<2>> From<F> for QuadraticExtension<F> {
|
||||
}
|
||||
|
||||
impl<F: Extendable<2>> Field for QuadraticExtension<F> {
|
||||
type PrimeField = F;
|
||||
|
||||
const ZERO: Self = Self([F::ZERO; 2]);
|
||||
const ONE: Self = Self([F::ONE, F::ZERO]);
|
||||
const TWO: Self = Self([F::TWO, F::ZERO]);
|
||||
const NEG_ONE: Self = Self([F::NEG_ONE, F::ZERO]);
|
||||
|
||||
const CHARACTERISTIC: u64 = F::CHARACTERISTIC;
|
||||
|
||||
// `p^2 - 1 = (p - 1)(p + 1)`. The `p - 1` term has a two-adicity of `F::TWO_ADICITY`. As
|
||||
// long as `F::TWO_ADICITY >= 2`, `p` can be written as `4n + 1`, so `p + 1` can be written as
|
||||
// `2(2n + 1)`, which has a 2-adicity of 1.
|
||||
const TWO_ADICITY: usize = F::TWO_ADICITY + 1;
|
||||
const CHARACTERISTIC_TWO_ADICITY: usize = F::CHARACTERISTIC_TWO_ADICITY;
|
||||
|
||||
const MULTIPLICATIVE_GROUP_GENERATOR: Self = Self(F::EXT_MULTIPLICATIVE_GROUP_GENERATOR);
|
||||
const POWER_OF_TWO_GENERATOR: Self = Self(F::EXT_POWER_OF_TWO_GENERATOR);
|
||||
|
||||
const BITS: usize = F::BITS * 2;
|
||||
|
||||
fn order() -> BigUint {
|
||||
F::order() * F::order()
|
||||
}
|
||||
fn characteristic() -> BigUint {
|
||||
F::characteristic()
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn square(&self) -> Self {
|
||||
@ -99,6 +102,15 @@ impl<F: Extendable<2>> Field for QuadraticExtension<F> {
|
||||
))
|
||||
}
|
||||
|
||||
fn from_biguint(n: BigUint) -> Self {
|
||||
let (high, low) = n.div_rem(&F::order());
|
||||
Self([F::from_biguint(low), F::from_biguint(high)])
|
||||
}
|
||||
|
||||
fn to_biguint(&self) -> BigUint {
|
||||
self.0[0].to_biguint() + F::order() * self.0[1].to_biguint()
|
||||
}
|
||||
|
||||
fn from_canonical_u64(n: u64) -> Self {
|
||||
F::from_canonical_u64(n).into()
|
||||
}
|
||||
|
||||
@ -4,6 +4,7 @@ use std::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssi
|
||||
|
||||
use num::bigint::BigUint;
|
||||
use num::traits::Pow;
|
||||
use num::Integer;
|
||||
use rand::Rng;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
@ -50,27 +51,29 @@ impl<F: Extendable<4>> From<F> for QuarticExtension<F> {
|
||||
}
|
||||
|
||||
impl<F: Extendable<4>> Field for QuarticExtension<F> {
|
||||
type PrimeField = F;
|
||||
|
||||
const ZERO: Self = Self([F::ZERO; 4]);
|
||||
const ONE: Self = Self([F::ONE, F::ZERO, F::ZERO, F::ZERO]);
|
||||
const TWO: Self = Self([F::TWO, F::ZERO, F::ZERO, F::ZERO]);
|
||||
const NEG_ONE: Self = Self([F::NEG_ONE, F::ZERO, F::ZERO, F::ZERO]);
|
||||
|
||||
const CHARACTERISTIC: u64 = F::ORDER;
|
||||
|
||||
// `p^4 - 1 = (p - 1)(p + 1)(p^2 + 1)`. The `p - 1` term has a two-adicity of `F::TWO_ADICITY`.
|
||||
// As long as `F::TWO_ADICITY >= 2`, `p` can be written as `4n + 1`, so `p + 1` can be written as
|
||||
// `2(2n + 1)`, which has a 2-adicity of 1. A similar argument can show that `p^2 + 1` also has
|
||||
// a 2-adicity of 1.
|
||||
const TWO_ADICITY: usize = F::TWO_ADICITY + 2;
|
||||
const CHARACTERISTIC_TWO_ADICITY: usize = F::CHARACTERISTIC_TWO_ADICITY;
|
||||
|
||||
const MULTIPLICATIVE_GROUP_GENERATOR: Self = Self(F::EXT_MULTIPLICATIVE_GROUP_GENERATOR);
|
||||
const POWER_OF_TWO_GENERATOR: Self = Self(F::EXT_POWER_OF_TWO_GENERATOR);
|
||||
|
||||
const BITS: usize = F::BITS * 4;
|
||||
|
||||
fn order() -> BigUint {
|
||||
F::order().pow(4u32)
|
||||
}
|
||||
fn characteristic() -> BigUint {
|
||||
F::characteristic()
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn square(&self) -> Self {
|
||||
@ -104,6 +107,26 @@ impl<F: Extendable<4>> Field for QuarticExtension<F> {
|
||||
))
|
||||
}
|
||||
|
||||
fn from_biguint(n: BigUint) -> Self {
|
||||
let (rest, first) = n.div_rem(&F::order());
|
||||
let (rest, second) = rest.div_rem(&F::order());
|
||||
let (rest, third) = rest.div_rem(&F::order());
|
||||
Self([
|
||||
F::from_biguint(first),
|
||||
F::from_biguint(second),
|
||||
F::from_biguint(third),
|
||||
F::from_biguint(rest),
|
||||
])
|
||||
}
|
||||
|
||||
fn to_biguint(&self) -> BigUint {
|
||||
let mut result = self.0[3].to_biguint();
|
||||
result = result * F::order() + self.0[2].to_biguint();
|
||||
result = result * F::order() + self.0[1].to_biguint();
|
||||
result = result * F::order() + self.0[0].to_biguint();
|
||||
result
|
||||
}
|
||||
|
||||
fn from_canonical_u64(n: u64) -> Self {
|
||||
F::from_canonical_u64(n).into()
|
||||
}
|
||||
|
||||
@ -1,4 +1,3 @@
|
||||
use std::convert::{TryFrom, TryInto};
|
||||
use std::ops::Range;
|
||||
|
||||
use crate::field::extension_field::algebra::ExtensionAlgebra;
|
||||
@ -33,6 +32,7 @@ impl<const D: usize> ExtensionTarget<D> {
|
||||
let arr = self.to_target_array();
|
||||
let k = (F::order() - 1u32) / (D as u64);
|
||||
let z0 = F::Extension::W.exp_biguint(&(k * count as u64));
|
||||
#[allow(clippy::needless_collect)]
|
||||
let zs = z0
|
||||
.powers()
|
||||
.take(D)
|
||||
|
||||
@ -5,8 +5,8 @@ use unroll::unroll_for_loops;
|
||||
|
||||
use crate::field::field_types::Field;
|
||||
use crate::field::packable::Packable;
|
||||
use crate::field::packed_field::{PackedField, Singleton};
|
||||
use crate::polynomial::polynomial::{PolynomialCoeffs, PolynomialValues};
|
||||
use crate::field::packed_field::PackedField;
|
||||
use crate::polynomial::{PolynomialCoeffs, PolynomialValues};
|
||||
use crate::util::{log2_strict, reverse_index_bits};
|
||||
|
||||
pub(crate) type FftRootTable<F> = Vec<Vec<F>>;
|
||||
@ -38,7 +38,7 @@ fn fft_dispatch<F: Field>(
|
||||
zero_factor: Option<usize>,
|
||||
root_table: Option<&FftRootTable<F>>,
|
||||
) -> Vec<F> {
|
||||
let computed_root_table = if let Some(_) = root_table {
|
||||
let computed_root_table = if root_table.is_some() {
|
||||
None
|
||||
} else {
|
||||
Some(fft_root_table(input.len()))
|
||||
@ -98,12 +98,12 @@ pub fn ifft_with_options<F: Field>(
|
||||
/// Generic FFT implementation that works with both scalar and packed inputs.
|
||||
#[unroll_for_loops]
|
||||
fn fft_classic_simd<P: PackedField>(
|
||||
values: &mut [P::FieldType],
|
||||
values: &mut [P::Scalar],
|
||||
r: usize,
|
||||
lg_n: usize,
|
||||
root_table: &FftRootTable<P::FieldType>,
|
||||
root_table: &FftRootTable<P::Scalar>,
|
||||
) {
|
||||
let lg_packed_width = P::LOG2_WIDTH; // 0 when P is a scalar.
|
||||
let lg_packed_width = log2_strict(P::WIDTH); // 0 when P is a scalar.
|
||||
let packed_values = P::pack_slice_mut(values);
|
||||
let packed_n = packed_values.len();
|
||||
debug_assert!(packed_n == 1 << (lg_n - lg_packed_width));
|
||||
@ -121,19 +121,18 @@ fn fft_classic_simd<P: PackedField>(
|
||||
let half_m = 1 << lg_half_m;
|
||||
|
||||
// Set omega to root_table[lg_half_m][0..half_m] but repeated.
|
||||
let mut omega_vec = P::zero().to_vec();
|
||||
for j in 0..omega_vec.len() {
|
||||
omega_vec[j] = root_table[lg_half_m][j % half_m];
|
||||
let mut omega = P::ZERO;
|
||||
for (j, omega_j) in omega.as_slice_mut().iter_mut().enumerate() {
|
||||
*omega_j = root_table[lg_half_m][j % half_m];
|
||||
}
|
||||
let omega = P::from_slice(&omega_vec[..]);
|
||||
|
||||
for k in (0..packed_n).step_by(2) {
|
||||
// We have two vectors and want to do math on pairs of adjacent elements (or for
|
||||
// lg_half_m > 0, pairs of adjacent blocks of elements). .interleave does the
|
||||
// appropriate shuffling and is its own inverse.
|
||||
let (u, v) = packed_values[k].interleave(packed_values[k + 1], lg_half_m);
|
||||
let (u, v) = packed_values[k].interleave(packed_values[k + 1], half_m);
|
||||
let t = omega * v;
|
||||
(packed_values[k], packed_values[k + 1]) = (u + t).interleave(u - t, lg_half_m);
|
||||
(packed_values[k], packed_values[k + 1]) = (u + t).interleave(u - t, half_m);
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -197,13 +196,13 @@ pub(crate) fn fft_classic<F: Field>(input: &[F], r: usize, root_table: &FftRootT
|
||||
}
|
||||
}
|
||||
|
||||
let lg_packed_width = <F as Packable>::PackedType::LOG2_WIDTH;
|
||||
let lg_packed_width = log2_strict(<F as Packable>::Packing::WIDTH);
|
||||
if lg_n <= lg_packed_width {
|
||||
// Need the slice to be at least the width of two packed vectors for the vectorized version
|
||||
// to work. Do this tiny problem in scalar.
|
||||
fft_classic_simd::<Singleton<F>>(&mut values[..], r, lg_n, &root_table);
|
||||
fft_classic_simd::<F>(&mut values[..], r, lg_n, root_table);
|
||||
} else {
|
||||
fft_classic_simd::<<F as Packable>::PackedType>(&mut values[..], r, lg_n, &root_table);
|
||||
fft_classic_simd::<<F as Packable>::Packing>(&mut values[..], r, lg_n, root_table);
|
||||
}
|
||||
values
|
||||
}
|
||||
@ -213,19 +212,23 @@ mod tests {
|
||||
use crate::field::fft::{fft, fft_with_options, ifft};
|
||||
use crate::field::field_types::Field;
|
||||
use crate::field::goldilocks_field::GoldilocksField;
|
||||
use crate::polynomial::polynomial::{PolynomialCoeffs, PolynomialValues};
|
||||
use crate::polynomial::{PolynomialCoeffs, PolynomialValues};
|
||||
use crate::util::{log2_ceil, log2_strict};
|
||||
|
||||
#[test]
|
||||
fn fft_and_ifft() {
|
||||
type F = GoldilocksField;
|
||||
let degree = 200;
|
||||
let degree_padded = log2_ceil(degree);
|
||||
let mut coefficients = Vec::new();
|
||||
for i in 0..degree {
|
||||
coefficients.push(F::from_canonical_usize(i * 1337 % 100));
|
||||
}
|
||||
let coefficients = PolynomialCoeffs::new_padded(coefficients);
|
||||
let degree = 200usize;
|
||||
let degree_padded = degree.next_power_of_two();
|
||||
|
||||
// Create a vector of coeffs; the first degree of them are
|
||||
// "random", the last degree_padded-degree of them are zero.
|
||||
let coeffs = (0..degree)
|
||||
.map(|i| F::from_canonical_usize(i * 1337 % 100))
|
||||
.chain(std::iter::repeat(F::ZERO).take(degree_padded - degree))
|
||||
.collect::<Vec<_>>();
|
||||
assert_eq!(coeffs.len(), degree_padded);
|
||||
let coefficients = PolynomialCoeffs { coeffs };
|
||||
|
||||
let points = fft(&coefficients);
|
||||
assert_eq!(points, evaluate_naive(&coefficients));
|
||||
@ -263,7 +266,7 @@ mod tests {
|
||||
|
||||
let values = subgroup
|
||||
.into_iter()
|
||||
.map(|x| evaluate_at_naive(&coefficients, x))
|
||||
.map(|x| evaluate_at_naive(coefficients, x))
|
||||
.collect();
|
||||
PolynomialValues::new(values)
|
||||
}
|
||||
@ -272,8 +275,8 @@ mod tests {
|
||||
let mut sum = F::ZERO;
|
||||
let mut point_power = F::ONE;
|
||||
for &c in &coefficients.coeffs {
|
||||
sum = sum + c * point_power;
|
||||
point_power = point_power * point;
|
||||
sum += c * point_power;
|
||||
point_power *= point;
|
||||
}
|
||||
sum
|
||||
}
|
||||
|
||||
@ -13,12 +13,15 @@ macro_rules! test_field_arithmetic {
|
||||
|
||||
#[test]
|
||||
fn batch_inversion() {
|
||||
let xs = (1..=3)
|
||||
.map(|i| <$field>::from_canonical_u64(i))
|
||||
.collect::<Vec<_>>();
|
||||
let invs = <$field>::batch_multiplicative_inverse(&xs);
|
||||
for (x, inv) in xs.into_iter().zip(invs) {
|
||||
assert_eq!(x * inv, <$field>::ONE);
|
||||
for n in 0..20 {
|
||||
let xs = (1..=n as u64)
|
||||
.map(|i| <$field>::from_canonical_u64(i))
|
||||
.collect::<Vec<_>>();
|
||||
let invs = <$field>::batch_multiplicative_inverse(&xs);
|
||||
assert_eq!(invs.len(), n);
|
||||
for (x, inv) in xs.into_iter().zip(invs) {
|
||||
assert_eq!(x * inv, <$field>::ONE);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -81,10 +84,24 @@ macro_rules! test_field_arithmetic {
|
||||
assert_eq!(base.exp_biguint(&pow), base.exp_biguint(&big_pow));
|
||||
assert_ne!(base.exp_biguint(&pow), base.exp_biguint(&big_pow_wrong));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn inverses() {
|
||||
type F = $field;
|
||||
|
||||
let x = F::rand();
|
||||
let x1 = x.inverse();
|
||||
let x2 = x1.inverse();
|
||||
let x3 = x2.inverse();
|
||||
|
||||
assert_eq!(x, x2);
|
||||
assert_eq!(x1, x3);
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
#[allow(clippy::eq_op)]
|
||||
pub(crate) fn test_add_neg_sub_mul<BF: Extendable<D>, const D: usize>() {
|
||||
let x = BF::Extension::rand();
|
||||
let y = BF::Extension::rand();
|
||||
|
||||
@ -1,11 +1,10 @@
|
||||
use std::convert::TryInto;
|
||||
use std::fmt::{Debug, Display};
|
||||
use std::hash::Hash;
|
||||
use std::iter::{Product, Sum};
|
||||
use std::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssign};
|
||||
|
||||
use num::bigint::BigUint;
|
||||
use num::{Integer, One, Zero};
|
||||
use num::{Integer, One, ToPrimitive, Zero};
|
||||
use rand::Rng;
|
||||
use serde::de::DeserializeOwned;
|
||||
use serde::Serialize;
|
||||
@ -43,24 +42,28 @@ pub trait Field:
|
||||
+ Serialize
|
||||
+ DeserializeOwned
|
||||
{
|
||||
type PrimeField: PrimeField;
|
||||
|
||||
const ZERO: Self;
|
||||
const ONE: Self;
|
||||
const TWO: Self;
|
||||
const NEG_ONE: Self;
|
||||
|
||||
const CHARACTERISTIC: u64;
|
||||
|
||||
/// The 2-adicity of this field's multiplicative group.
|
||||
const TWO_ADICITY: usize;
|
||||
|
||||
/// The field's characteristic and it's 2-adicity.
|
||||
/// Set to `None` when the characteristic doesn't fit in a u64.
|
||||
const CHARACTERISTIC_TWO_ADICITY: usize;
|
||||
|
||||
/// Generator of the entire multiplicative group, i.e. all non-zero elements.
|
||||
const MULTIPLICATIVE_GROUP_GENERATOR: Self;
|
||||
/// Generator of a multiplicative subgroup of order `2^TWO_ADICITY`.
|
||||
const POWER_OF_TWO_GENERATOR: Self;
|
||||
|
||||
/// The bit length of the field order.
|
||||
const BITS: usize;
|
||||
|
||||
fn order() -> BigUint;
|
||||
fn characteristic() -> BigUint;
|
||||
|
||||
#[inline]
|
||||
fn is_zero(&self) -> bool {
|
||||
@ -92,6 +95,10 @@ pub trait Field:
|
||||
self.square() * *self
|
||||
}
|
||||
|
||||
fn triple(&self) -> Self {
|
||||
*self * (Self::ONE + Self::TWO)
|
||||
}
|
||||
|
||||
/// Compute the multiplicative inverse of this field element.
|
||||
fn try_inverse(&self) -> Option<Self>;
|
||||
|
||||
@ -103,34 +110,91 @@ pub trait Field:
|
||||
// This is Montgomery's trick. At a high level, we invert the product of the given field
|
||||
// elements, then derive the individual inverses from that via multiplication.
|
||||
|
||||
// The usual Montgomery trick involves calculating an array of cumulative products,
|
||||
// resulting in a long dependency chain. To increase instruction-level parallelism, we
|
||||
// compute WIDTH separate cumulative product arrays that only meet at the end.
|
||||
|
||||
// Higher WIDTH increases instruction-level parallelism, but too high a value will cause us
|
||||
// to run out of registers.
|
||||
const WIDTH: usize = 4;
|
||||
// JN note: WIDTH is 4. The code is specialized to this value and will need
|
||||
// modification if it is changed. I tried to make it more generic, but Rust's const
|
||||
// generics are not yet good enough.
|
||||
|
||||
// Handle special cases. Paradoxically, below is repetitive but concise.
|
||||
// The branches should be very predictable.
|
||||
let n = x.len();
|
||||
if n == 0 {
|
||||
return Vec::new();
|
||||
}
|
||||
if n == 1 {
|
||||
} else if n == 1 {
|
||||
return vec![x[0].inverse()];
|
||||
} else if n == 2 {
|
||||
let x01 = x[0] * x[1];
|
||||
let x01inv = x01.inverse();
|
||||
return vec![x01inv * x[1], x01inv * x[0]];
|
||||
} else if n == 3 {
|
||||
let x01 = x[0] * x[1];
|
||||
let x012 = x01 * x[2];
|
||||
let x012inv = x012.inverse();
|
||||
let x01inv = x012inv * x[2];
|
||||
return vec![x01inv * x[1], x01inv * x[0], x012inv * x01];
|
||||
}
|
||||
debug_assert!(n >= WIDTH);
|
||||
|
||||
// Fill buf with cumulative product of x.
|
||||
let mut buf = Vec::with_capacity(n);
|
||||
let mut cumul_prod = x[0];
|
||||
buf.push(cumul_prod);
|
||||
for i in 1..n {
|
||||
cumul_prod *= x[i];
|
||||
buf.push(cumul_prod);
|
||||
// Buf is reused for a few things to save allocations.
|
||||
// Fill buf with cumulative product of x, only taking every 4th value. Concretely, buf will
|
||||
// be [
|
||||
// x[0], x[1], x[2], x[3],
|
||||
// x[0] * x[4], x[1] * x[5], x[2] * x[6], x[3] * x[7],
|
||||
// x[0] * x[4] * x[8], x[1] * x[5] * x[9], x[2] * x[6] * x[10], x[3] * x[7] * x[11],
|
||||
// ...
|
||||
// ].
|
||||
// If n is not a multiple of WIDTH, the result is truncated from the end. For example,
|
||||
// for n == 5, we get [x[0], x[1], x[2], x[3], x[0] * x[4]].
|
||||
let mut buf: Vec<Self> = Vec::with_capacity(n);
|
||||
// cumul_prod holds the last WIDTH elements of buf. This is redundant, but it's how we
|
||||
// convince LLVM to keep the values in the registers.
|
||||
let mut cumul_prod: [Self; WIDTH] = x[..WIDTH].try_into().unwrap();
|
||||
buf.extend(cumul_prod);
|
||||
for (i, &xi) in x[WIDTH..].iter().enumerate() {
|
||||
cumul_prod[i % WIDTH] *= xi;
|
||||
buf.push(cumul_prod[i % WIDTH]);
|
||||
}
|
||||
debug_assert_eq!(buf.len(), n);
|
||||
|
||||
// At this stage buf contains the the cumulative product of x. We reuse the buffer for
|
||||
// efficiency. At the end of the loop, it is filled with inverses of x.
|
||||
let mut a_inv = cumul_prod.inverse();
|
||||
buf[n - 1] = buf[n - 2] * a_inv;
|
||||
for i in (1..n - 1).rev() {
|
||||
a_inv = x[i + 1] * a_inv;
|
||||
// buf[i - 1] has not been written to by this loop, so it equals x[0] * ... x[n - 1].
|
||||
buf[i] = buf[i - 1] * a_inv;
|
||||
let mut a_inv = {
|
||||
// This is where the four dependency chains meet.
|
||||
// Take the last four elements of buf and invert them all.
|
||||
let c01 = cumul_prod[0] * cumul_prod[1];
|
||||
let c23 = cumul_prod[2] * cumul_prod[3];
|
||||
let c0123 = c01 * c23;
|
||||
let c0123inv = c0123.inverse();
|
||||
let c01inv = c0123inv * c23;
|
||||
let c23inv = c0123inv * c01;
|
||||
[
|
||||
c01inv * cumul_prod[1],
|
||||
c01inv * cumul_prod[0],
|
||||
c23inv * cumul_prod[3],
|
||||
c23inv * cumul_prod[2],
|
||||
]
|
||||
};
|
||||
|
||||
for i in (WIDTH..n).rev() {
|
||||
// buf[i - WIDTH] has not been written to by this loop, so it equals
|
||||
// x[i % WIDTH] * x[i % WIDTH + WIDTH] * ... * x[i - WIDTH].
|
||||
buf[i] = buf[i - WIDTH] * a_inv[i % WIDTH];
|
||||
// buf[i] now holds the inverse of x[i].
|
||||
a_inv[i % WIDTH] *= x[i];
|
||||
}
|
||||
buf[0] = x[1] * a_inv;
|
||||
for i in (0..WIDTH).rev() {
|
||||
buf[i] = a_inv[i];
|
||||
}
|
||||
|
||||
for (&bi, &xi) in buf.iter().zip(x) {
|
||||
// Sanity check only.
|
||||
debug_assert_eq!(bi * xi, Self::ONE);
|
||||
}
|
||||
|
||||
buf
|
||||
}
|
||||
|
||||
@ -142,29 +206,31 @@ pub trait Field:
|
||||
// exp exceeds t, we repeatedly multiply by 2^-t and reduce
|
||||
// exp until it's in the right range.
|
||||
|
||||
let p = Self::CHARACTERISTIC;
|
||||
if let Some(p) = Self::characteristic().to_u64() {
|
||||
// NB: The only reason this is split into two cases is to save
|
||||
// the multiplication (and possible calculation of
|
||||
// inverse_2_pow_adicity) in the usual case that exp <=
|
||||
// TWO_ADICITY. Can remove the branch and simplify if that
|
||||
// saving isn't worth it.
|
||||
|
||||
// NB: The only reason this is split into two cases is to save
|
||||
// the multiplication (and possible calculation of
|
||||
// inverse_2_pow_adicity) in the usual case that exp <=
|
||||
// TWO_ADICITY. Can remove the branch and simplify if that
|
||||
// saving isn't worth it.
|
||||
if exp > Self::CHARACTERISTIC_TWO_ADICITY {
|
||||
// NB: This should be a compile-time constant
|
||||
let inverse_2_pow_adicity: Self =
|
||||
Self::from_canonical_u64(p - ((p - 1) >> Self::CHARACTERISTIC_TWO_ADICITY));
|
||||
|
||||
if exp > Self::PrimeField::TWO_ADICITY {
|
||||
// NB: This should be a compile-time constant
|
||||
let inverse_2_pow_adicity: Self =
|
||||
Self::from_canonical_u64(p - ((p - 1) >> Self::PrimeField::TWO_ADICITY));
|
||||
let mut res = inverse_2_pow_adicity;
|
||||
let mut e = exp - Self::CHARACTERISTIC_TWO_ADICITY;
|
||||
|
||||
let mut res = inverse_2_pow_adicity;
|
||||
let mut e = exp - Self::PrimeField::TWO_ADICITY;
|
||||
|
||||
while e > Self::PrimeField::TWO_ADICITY {
|
||||
res *= inverse_2_pow_adicity;
|
||||
e -= Self::PrimeField::TWO_ADICITY;
|
||||
while e > Self::CHARACTERISTIC_TWO_ADICITY {
|
||||
res *= inverse_2_pow_adicity;
|
||||
e -= Self::CHARACTERISTIC_TWO_ADICITY;
|
||||
}
|
||||
res * Self::from_canonical_u64(p - ((p - 1) >> e))
|
||||
} else {
|
||||
Self::from_canonical_u64(p - ((p - 1) >> exp))
|
||||
}
|
||||
res * Self::from_canonical_u64(p - ((p - 1) >> e))
|
||||
} else {
|
||||
Self::from_canonical_u64(p - ((p - 1) >> exp))
|
||||
Self::TWO.inverse().exp_u64(exp as u64)
|
||||
}
|
||||
}
|
||||
|
||||
@ -206,6 +272,11 @@ pub trait Field:
|
||||
subgroup.into_iter().map(|x| x * shift).collect()
|
||||
}
|
||||
|
||||
// TODO: move these to a new `PrimeField` trait (for all prime fields, not just 64-bit ones)
|
||||
fn from_biguint(n: BigUint) -> Self;
|
||||
|
||||
fn to_biguint(&self) -> BigUint;
|
||||
|
||||
fn from_canonical_u64(n: u64) -> Self;
|
||||
|
||||
fn from_canonical_u32(n: u32) -> Self {
|
||||
@ -274,7 +345,7 @@ pub trait Field:
|
||||
}
|
||||
|
||||
fn kth_root_u64(&self, k: u64) -> Self {
|
||||
let p = Self::order().clone();
|
||||
let p = Self::order();
|
||||
let p_minus_1 = &p - 1u32;
|
||||
debug_assert!(
|
||||
Self::is_monomial_permutation_u64(k),
|
||||
@ -356,6 +427,7 @@ pub trait PrimeField: Field {
|
||||
unsafe { self.sub_canonical_u64(1) }
|
||||
}
|
||||
|
||||
/// # Safety
|
||||
/// Equivalent to *self + Self::from_canonical_u64(rhs), but may be cheaper. The caller must
|
||||
/// ensure that 0 <= rhs < Self::ORDER. The function may return incorrect results if this
|
||||
/// precondition is not met. It is marked unsafe for this reason.
|
||||
@ -365,6 +437,7 @@ pub trait PrimeField: Field {
|
||||
*self + Self::from_canonical_u64(rhs)
|
||||
}
|
||||
|
||||
/// # Safety
|
||||
/// Equivalent to *self - Self::from_canonical_u64(rhs), but may be cheaper. The caller must
|
||||
/// ensure that 0 <= rhs < Self::ORDER. The function may return incorrect results if this
|
||||
/// precondition is not met. It is marked unsafe for this reason.
|
||||
|
||||
@ -4,7 +4,7 @@ use std::hash::{Hash, Hasher};
|
||||
use std::iter::{Product, Sum};
|
||||
use std::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssign};
|
||||
|
||||
use num::BigUint;
|
||||
use num::{BigUint, Integer};
|
||||
use rand::Rng;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
@ -62,15 +62,13 @@ impl Debug for GoldilocksField {
|
||||
}
|
||||
|
||||
impl Field for GoldilocksField {
|
||||
type PrimeField = Self;
|
||||
|
||||
const ZERO: Self = Self(0);
|
||||
const ONE: Self = Self(1);
|
||||
const TWO: Self = Self(2);
|
||||
const NEG_ONE: Self = Self(Self::ORDER - 1);
|
||||
const CHARACTERISTIC: u64 = Self::ORDER;
|
||||
|
||||
const TWO_ADICITY: usize = 32;
|
||||
const CHARACTERISTIC_TWO_ADICITY: usize = Self::TWO_ADICITY;
|
||||
|
||||
// Sage: `g = GF(p).multiplicative_generator()`
|
||||
const MULTIPLICATIVE_GROUP_GENERATOR: Self = Self(7);
|
||||
@ -82,15 +80,28 @@ impl Field for GoldilocksField {
|
||||
// ```
|
||||
const POWER_OF_TWO_GENERATOR: Self = Self(1753635133440165772);
|
||||
|
||||
const BITS: usize = 64;
|
||||
|
||||
fn order() -> BigUint {
|
||||
Self::ORDER.into()
|
||||
}
|
||||
fn characteristic() -> BigUint {
|
||||
Self::order()
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn try_inverse(&self) -> Option<Self> {
|
||||
try_inverse_u64(self)
|
||||
}
|
||||
|
||||
fn from_biguint(n: BigUint) -> Self {
|
||||
Self(n.mod_floor(&Self::order()).to_u64_digits()[0])
|
||||
}
|
||||
|
||||
fn to_biguint(&self) -> BigUint {
|
||||
self.to_canonical_u64().into()
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn from_canonical_u64(n: u64) -> Self {
|
||||
debug_assert!(n < Self::ORDER);
|
||||
@ -312,6 +323,7 @@ impl RichField for GoldilocksField {}
|
||||
#[inline(always)]
|
||||
#[cfg(target_arch = "x86_64")]
|
||||
unsafe fn add_no_canonicalize_trashing_input(x: u64, y: u64) -> u64 {
|
||||
use std::arch::asm;
|
||||
let res_wrapped: u64;
|
||||
let adjustment: u64;
|
||||
asm!(
|
||||
@ -352,6 +364,7 @@ unsafe fn add_no_canonicalize_trashing_input(x: u64, y: u64) -> u64 {
|
||||
#[inline(always)]
|
||||
#[cfg(target_arch = "x86_64")]
|
||||
unsafe fn sub_no_canonicalize_trashing_input(x: u64, y: u64) -> u64 {
|
||||
use std::arch::asm;
|
||||
let res_wrapped: u64;
|
||||
let adjustment: u64;
|
||||
asm!(
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
use crate::field::fft::ifft;
|
||||
use crate::field::field_types::Field;
|
||||
use crate::polynomial::polynomial::{PolynomialCoeffs, PolynomialValues};
|
||||
use crate::polynomial::{PolynomialCoeffs, PolynomialValues};
|
||||
use crate::util::log2_ceil;
|
||||
|
||||
/// Computes the unique degree < n interpolant of an arbitrary list of n (point, value) pairs.
|
||||
@ -80,7 +80,7 @@ mod tests {
|
||||
use crate::field::extension_field::quartic::QuarticExtension;
|
||||
use crate::field::field_types::Field;
|
||||
use crate::field::goldilocks_field::GoldilocksField;
|
||||
use crate::polynomial::polynomial::PolynomialCoeffs;
|
||||
use crate::polynomial::PolynomialCoeffs;
|
||||
|
||||
#[test]
|
||||
fn interpolant_random() {
|
||||
|
||||
@ -7,7 +7,8 @@ pub(crate) mod interpolation;
|
||||
mod inversion;
|
||||
pub(crate) mod packable;
|
||||
pub(crate) mod packed_field;
|
||||
pub mod secp256k1;
|
||||
pub mod secp256k1_base;
|
||||
pub mod secp256k1_scalar;
|
||||
|
||||
#[cfg(target_feature = "avx2")]
|
||||
pub(crate) mod packed_avx2;
|
||||
|
||||
@ -1,18 +1,18 @@
|
||||
use crate::field::field_types::Field;
|
||||
use crate::field::packed_field::{PackedField, Singleton};
|
||||
use crate::field::packed_field::PackedField;
|
||||
|
||||
/// Points us to the default packing for a particular field. There may me multiple choices of
|
||||
/// PackedField for a particular Field (e.g. Singleton works for all fields), but this is the
|
||||
/// PackedField for a particular Field (e.g. every Field is also a PackedField), but this is the
|
||||
/// recommended one. The recommended packing varies by target_arch and target_feature.
|
||||
pub trait Packable: Field {
|
||||
type PackedType: PackedField<FieldType = Self>;
|
||||
type Packing: PackedField<Scalar = Self>;
|
||||
}
|
||||
|
||||
impl<F: Field> Packable for F {
|
||||
default type PackedType = Singleton<Self>;
|
||||
default type Packing = Self;
|
||||
}
|
||||
|
||||
#[cfg(target_feature = "avx2")]
|
||||
impl Packable for crate::field::goldilocks_field::GoldilocksField {
|
||||
type PackedType = crate::field::packed_avx2::PackedGoldilocksAVX2;
|
||||
type Packing = crate::field::packed_avx2::PackedGoldilocksAvx2;
|
||||
}
|
||||
|
||||
@ -2,20 +2,20 @@ use core::arch::x86_64::*;
|
||||
use std::fmt;
|
||||
use std::fmt::{Debug, Formatter};
|
||||
use std::iter::{Product, Sum};
|
||||
use std::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign};
|
||||
use std::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssign};
|
||||
|
||||
use crate::field::field_types::PrimeField;
|
||||
use crate::field::packed_avx2::common::{
|
||||
add_no_canonicalize_64_64s_s, epsilon, field_order, ReducibleAVX2,
|
||||
add_no_canonicalize_64_64s_s, epsilon, field_order, shift, ReducibleAvx2,
|
||||
};
|
||||
use crate::field::packed_field::PackedField;
|
||||
|
||||
// PackedPrimeField wraps an array of four u64s, with the new and get methods to convert that
|
||||
// Avx2PrimeField wraps an array of four u64s, with the new and get methods to convert that
|
||||
// array to and from __m256i, which is the type we actually operate on. This indirection is a
|
||||
// terrible trick to change PackedPrimeField's alignment.
|
||||
// We'd like to be able to cast slices of PrimeField to slices of PackedPrimeField. Rust
|
||||
// terrible trick to change Avx2PrimeField's alignment.
|
||||
// We'd like to be able to cast slices of PrimeField to slices of Avx2PrimeField. Rust
|
||||
// aligns __m256i to 32 bytes but PrimeField has a lower alignment. That alignment extends to
|
||||
// PackedPrimeField and it appears that it cannot be lowered with #[repr(C, blah)]. It is
|
||||
// Avx2PrimeField and it appears that it cannot be lowered with #[repr(C, blah)]. It is
|
||||
// important for Rust not to assume 32-byte alignment, so we cannot wrap __m256i directly.
|
||||
// There are two versions of vectorized load/store instructions on x86: aligned (vmovaps and
|
||||
// friends) and unaligned (vmovups etc.). The difference between them is that aligned loads and
|
||||
@ -23,12 +23,12 @@ use crate::field::packed_field::PackedField;
|
||||
// were faster, and although this is no longer the case, compilers prefer the aligned versions if
|
||||
// they know that the address is aligned. Using aligned instructions on unaligned addresses leads to
|
||||
// bugs that can be frustrating to diagnose. Hence, we can't have Rust assuming alignment, and
|
||||
// therefore PackedPrimeField wraps [F; 4] and not __m256i.
|
||||
// therefore Avx2PrimeField wraps [F; 4] and not __m256i.
|
||||
#[derive(Copy, Clone)]
|
||||
#[repr(transparent)]
|
||||
pub struct PackedPrimeField<F: ReducibleAVX2>(pub [F; 4]);
|
||||
pub struct Avx2PrimeField<F: ReducibleAvx2>(pub [F; 4]);
|
||||
|
||||
impl<F: ReducibleAVX2> PackedPrimeField<F> {
|
||||
impl<F: ReducibleAvx2> Avx2PrimeField<F> {
|
||||
#[inline]
|
||||
fn new(x: __m256i) -> Self {
|
||||
let mut obj = Self([F::ZERO; 4]);
|
||||
@ -43,84 +43,111 @@ impl<F: ReducibleAVX2> PackedPrimeField<F> {
|
||||
let ptr = (&self.0).as_ptr().cast::<__m256i>();
|
||||
unsafe { _mm256_loadu_si256(ptr) }
|
||||
}
|
||||
|
||||
/// Addition that assumes x + y < 2^64 + F::ORDER. May return incorrect results if this
|
||||
/// condition is not met, hence it is marked unsafe.
|
||||
#[inline]
|
||||
pub unsafe fn add_canonical_u64(&self, rhs: __m256i) -> Self {
|
||||
Self::new(add_canonical_u64::<F>(self.get(), rhs))
|
||||
}
|
||||
}
|
||||
|
||||
impl<F: ReducibleAVX2> Add<Self> for PackedPrimeField<F> {
|
||||
impl<F: ReducibleAvx2> Add<Self> for Avx2PrimeField<F> {
|
||||
type Output = Self;
|
||||
#[inline]
|
||||
fn add(self, rhs: Self) -> Self {
|
||||
Self::new(unsafe { add::<F>(self.get(), rhs.get()) })
|
||||
}
|
||||
}
|
||||
impl<F: ReducibleAVX2> Add<F> for PackedPrimeField<F> {
|
||||
impl<F: ReducibleAvx2> Add<F> for Avx2PrimeField<F> {
|
||||
type Output = Self;
|
||||
#[inline]
|
||||
fn add(self, rhs: F) -> Self {
|
||||
self + Self::broadcast(rhs)
|
||||
self + Self::from(rhs)
|
||||
}
|
||||
}
|
||||
impl<F: ReducibleAVX2> AddAssign<Self> for PackedPrimeField<F> {
|
||||
impl<F: ReducibleAvx2> Add<Avx2PrimeField<F>> for <Avx2PrimeField<F> as PackedField>::Scalar {
|
||||
type Output = Avx2PrimeField<F>;
|
||||
#[inline]
|
||||
fn add(self, rhs: Self::Output) -> Self::Output {
|
||||
Self::Output::from(self) + rhs
|
||||
}
|
||||
}
|
||||
impl<F: ReducibleAvx2> AddAssign<Self> for Avx2PrimeField<F> {
|
||||
#[inline]
|
||||
fn add_assign(&mut self, rhs: Self) {
|
||||
*self = *self + rhs;
|
||||
}
|
||||
}
|
||||
impl<F: ReducibleAVX2> AddAssign<F> for PackedPrimeField<F> {
|
||||
impl<F: ReducibleAvx2> AddAssign<F> for Avx2PrimeField<F> {
|
||||
#[inline]
|
||||
fn add_assign(&mut self, rhs: F) {
|
||||
*self = *self + rhs;
|
||||
}
|
||||
}
|
||||
|
||||
impl<F: ReducibleAVX2> Debug for PackedPrimeField<F> {
|
||||
impl<F: ReducibleAvx2> Debug for Avx2PrimeField<F> {
|
||||
#[inline]
|
||||
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
|
||||
write!(f, "({:?})", self.get())
|
||||
}
|
||||
}
|
||||
|
||||
impl<F: ReducibleAVX2> Default for PackedPrimeField<F> {
|
||||
impl<F: ReducibleAvx2> Default for Avx2PrimeField<F> {
|
||||
#[inline]
|
||||
fn default() -> Self {
|
||||
Self::zero()
|
||||
Self::ZERO
|
||||
}
|
||||
}
|
||||
|
||||
impl<F: ReducibleAVX2> Mul<Self> for PackedPrimeField<F> {
|
||||
impl<F: ReducibleAvx2> Div<F> for Avx2PrimeField<F> {
|
||||
type Output = Self;
|
||||
#[inline]
|
||||
fn div(self, rhs: F) -> Self {
|
||||
self * rhs.inverse()
|
||||
}
|
||||
}
|
||||
impl<F: ReducibleAvx2> DivAssign<F> for Avx2PrimeField<F> {
|
||||
#[inline]
|
||||
fn div_assign(&mut self, rhs: F) {
|
||||
*self *= rhs.inverse();
|
||||
}
|
||||
}
|
||||
|
||||
impl<F: ReducibleAvx2> From<F> for Avx2PrimeField<F> {
|
||||
fn from(x: F) -> Self {
|
||||
Self([x; 4])
|
||||
}
|
||||
}
|
||||
|
||||
impl<F: ReducibleAvx2> Mul<Self> for Avx2PrimeField<F> {
|
||||
type Output = Self;
|
||||
#[inline]
|
||||
fn mul(self, rhs: Self) -> Self {
|
||||
Self::new(unsafe { mul::<F>(self.get(), rhs.get()) })
|
||||
}
|
||||
}
|
||||
impl<F: ReducibleAVX2> Mul<F> for PackedPrimeField<F> {
|
||||
impl<F: ReducibleAvx2> Mul<F> for Avx2PrimeField<F> {
|
||||
type Output = Self;
|
||||
#[inline]
|
||||
fn mul(self, rhs: F) -> Self {
|
||||
self * Self::broadcast(rhs)
|
||||
self * Self::from(rhs)
|
||||
}
|
||||
}
|
||||
impl<F: ReducibleAVX2> MulAssign<Self> for PackedPrimeField<F> {
|
||||
impl<F: ReducibleAvx2> Mul<Avx2PrimeField<F>> for <Avx2PrimeField<F> as PackedField>::Scalar {
|
||||
type Output = Avx2PrimeField<F>;
|
||||
#[inline]
|
||||
fn mul(self, rhs: Avx2PrimeField<F>) -> Self::Output {
|
||||
Self::Output::from(self) * rhs
|
||||
}
|
||||
}
|
||||
impl<F: ReducibleAvx2> MulAssign<Self> for Avx2PrimeField<F> {
|
||||
#[inline]
|
||||
fn mul_assign(&mut self, rhs: Self) {
|
||||
*self = *self * rhs;
|
||||
}
|
||||
}
|
||||
impl<F: ReducibleAVX2> MulAssign<F> for PackedPrimeField<F> {
|
||||
impl<F: ReducibleAvx2> MulAssign<F> for Avx2PrimeField<F> {
|
||||
#[inline]
|
||||
fn mul_assign(&mut self, rhs: F) {
|
||||
*self = *self * rhs;
|
||||
}
|
||||
}
|
||||
|
||||
impl<F: ReducibleAVX2> Neg for PackedPrimeField<F> {
|
||||
impl<F: ReducibleAvx2> Neg for Avx2PrimeField<F> {
|
||||
type Output = Self;
|
||||
#[inline]
|
||||
fn neg(self) -> Self {
|
||||
@ -128,52 +155,59 @@ impl<F: ReducibleAVX2> Neg for PackedPrimeField<F> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<F: ReducibleAVX2> Product for PackedPrimeField<F> {
|
||||
impl<F: ReducibleAvx2> Product for Avx2PrimeField<F> {
|
||||
#[inline]
|
||||
fn product<I: Iterator<Item = Self>>(iter: I) -> Self {
|
||||
iter.reduce(|x, y| x * y).unwrap_or(Self::one())
|
||||
iter.reduce(|x, y| x * y).unwrap_or(Self::ONE)
|
||||
}
|
||||
}
|
||||
|
||||
impl<F: ReducibleAVX2> PackedField for PackedPrimeField<F> {
|
||||
const LOG2_WIDTH: usize = 2;
|
||||
unsafe impl<F: ReducibleAvx2> PackedField for Avx2PrimeField<F> {
|
||||
const WIDTH: usize = 4;
|
||||
|
||||
type FieldType = F;
|
||||
type Scalar = F;
|
||||
type PackedPrimeField = Avx2PrimeField<F>;
|
||||
|
||||
const ZERO: Self = Self([F::ZERO; 4]);
|
||||
const ONE: Self = Self([F::ONE; 4]);
|
||||
|
||||
#[inline]
|
||||
fn broadcast(x: F) -> Self {
|
||||
Self([x; 4])
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn from_arr(arr: [F; Self::WIDTH]) -> Self {
|
||||
fn from_arr(arr: [Self::Scalar; Self::WIDTH]) -> Self {
|
||||
Self(arr)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn to_arr(&self) -> [F; Self::WIDTH] {
|
||||
fn as_arr(&self) -> [Self::Scalar; Self::WIDTH] {
|
||||
self.0
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn from_slice(slice: &[F]) -> Self {
|
||||
assert!(slice.len() == 4);
|
||||
Self([slice[0], slice[1], slice[2], slice[3]])
|
||||
fn from_slice(slice: &[Self::Scalar]) -> &Self {
|
||||
assert_eq!(slice.len(), Self::WIDTH);
|
||||
unsafe { &*slice.as_ptr().cast() }
|
||||
}
|
||||
#[inline]
|
||||
fn from_slice_mut(slice: &mut [Self::Scalar]) -> &mut Self {
|
||||
assert_eq!(slice.len(), Self::WIDTH);
|
||||
unsafe { &mut *slice.as_mut_ptr().cast() }
|
||||
}
|
||||
#[inline]
|
||||
fn as_slice(&self) -> &[Self::Scalar] {
|
||||
&self.0[..]
|
||||
}
|
||||
#[inline]
|
||||
fn as_slice_mut(&mut self) -> &mut [Self::Scalar] {
|
||||
&mut self.0[..]
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn to_vec(&self) -> Vec<F> {
|
||||
self.0.into()
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn interleave(&self, other: Self, r: usize) -> (Self, Self) {
|
||||
fn interleave(&self, other: Self, block_len: usize) -> (Self, Self) {
|
||||
let (v0, v1) = (self.get(), other.get());
|
||||
let (res0, res1) = match r {
|
||||
0 => unsafe { interleave0(v0, v1) },
|
||||
let (res0, res1) = match block_len {
|
||||
1 => unsafe { interleave1(v0, v1) },
|
||||
2 => (v0, v1),
|
||||
_ => panic!("r cannot be more than LOG2_WIDTH"),
|
||||
2 => unsafe { interleave2(v0, v1) },
|
||||
4 => (v0, v1),
|
||||
_ => panic!("unsupported block_len"),
|
||||
};
|
||||
(Self::new(res0), Self::new(res1))
|
||||
}
|
||||
@ -184,47 +218,47 @@ impl<F: ReducibleAVX2> PackedField for PackedPrimeField<F> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<F: ReducibleAVX2> Sub<Self> for PackedPrimeField<F> {
|
||||
impl<F: ReducibleAvx2> Sub<Self> for Avx2PrimeField<F> {
|
||||
type Output = Self;
|
||||
#[inline]
|
||||
fn sub(self, rhs: Self) -> Self {
|
||||
Self::new(unsafe { sub::<F>(self.get(), rhs.get()) })
|
||||
}
|
||||
}
|
||||
impl<F: ReducibleAVX2> Sub<F> for PackedPrimeField<F> {
|
||||
impl<F: ReducibleAvx2> Sub<F> for Avx2PrimeField<F> {
|
||||
type Output = Self;
|
||||
#[inline]
|
||||
fn sub(self, rhs: F) -> Self {
|
||||
self - Self::broadcast(rhs)
|
||||
self - Self::from(rhs)
|
||||
}
|
||||
}
|
||||
impl<F: ReducibleAVX2> SubAssign<Self> for PackedPrimeField<F> {
|
||||
impl<F: ReducibleAvx2> Sub<Avx2PrimeField<F>> for <Avx2PrimeField<F> as PackedField>::Scalar {
|
||||
type Output = Avx2PrimeField<F>;
|
||||
#[inline]
|
||||
fn sub(self, rhs: Avx2PrimeField<F>) -> Self::Output {
|
||||
Self::Output::from(self) - rhs
|
||||
}
|
||||
}
|
||||
impl<F: ReducibleAvx2> SubAssign<Self> for Avx2PrimeField<F> {
|
||||
#[inline]
|
||||
fn sub_assign(&mut self, rhs: Self) {
|
||||
*self = *self - rhs;
|
||||
}
|
||||
}
|
||||
impl<F: ReducibleAVX2> SubAssign<F> for PackedPrimeField<F> {
|
||||
impl<F: ReducibleAvx2> SubAssign<F> for Avx2PrimeField<F> {
|
||||
#[inline]
|
||||
fn sub_assign(&mut self, rhs: F) {
|
||||
*self = *self - rhs;
|
||||
}
|
||||
}
|
||||
|
||||
impl<F: ReducibleAVX2> Sum for PackedPrimeField<F> {
|
||||
impl<F: ReducibleAvx2> Sum for Avx2PrimeField<F> {
|
||||
#[inline]
|
||||
fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
|
||||
iter.reduce(|x, y| x + y).unwrap_or(Self::zero())
|
||||
iter.reduce(|x, y| x + y).unwrap_or(Self::ZERO)
|
||||
}
|
||||
}
|
||||
|
||||
const SIGN_BIT: u64 = 1 << 63;
|
||||
|
||||
#[inline]
|
||||
unsafe fn sign_bit() -> __m256i {
|
||||
_mm256_set1_epi64x(SIGN_BIT as i64)
|
||||
}
|
||||
|
||||
// Resources:
|
||||
// 1. Intel Intrinsics Guide for explanation of each intrinsic:
|
||||
// https://software.intel.com/sites/landingpage/IntrinsicsGuide/
|
||||
@ -274,12 +308,6 @@ unsafe fn sign_bit() -> __m256i {
|
||||
// Notice that the above 3-value addition still only requires two calls to shift, just like our
|
||||
// 2-value addition.
|
||||
|
||||
/// Add 2^63 with overflow. Needed to emulate unsigned comparisons (see point 3. above).
|
||||
#[inline]
|
||||
unsafe fn shift(x: __m256i) -> __m256i {
|
||||
_mm256_xor_si256(x, sign_bit())
|
||||
}
|
||||
|
||||
/// Convert to canonical representation.
|
||||
/// The argument is assumed to be shifted by 1 << 63 (i.e. x_s = x + 1<<63, where x is the field
|
||||
/// value). The returned value is similarly shifted by 1 << 63 (i.e. we return y_s = y + (1<<63),
|
||||
@ -293,14 +321,6 @@ unsafe fn canonicalize_s<F: PrimeField>(x_s: __m256i) -> __m256i {
|
||||
_mm256_add_epi64(x_s, wrapback_amt)
|
||||
}
|
||||
|
||||
/// Addition that assumes x + y < 2^64 + F::ORDER.
|
||||
#[inline]
|
||||
unsafe fn add_canonical_u64<F: PrimeField>(x: __m256i, y: __m256i) -> __m256i {
|
||||
let y_s = shift(y);
|
||||
let res_s = add_no_canonicalize_64_64s_s::<F>(x, y_s);
|
||||
shift(res_s)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
unsafe fn add<F: PrimeField>(x: __m256i, y: __m256i) -> __m256i {
|
||||
let y_s = shift(y);
|
||||
@ -326,78 +346,94 @@ unsafe fn neg<F: PrimeField>(y: __m256i) -> __m256i {
|
||||
_mm256_sub_epi64(shift(field_order::<F>()), canonicalize_s::<F>(y_s))
|
||||
}
|
||||
|
||||
/// Full 64-bit by 64-bit multiplication. This emulated multiplication is 1.5x slower than the
|
||||
/// Full 64-bit by 64-bit multiplication. This emulated multiplication is 1.33x slower than the
|
||||
/// scalar instruction, but may be worth it if we want our data to live in vector registers.
|
||||
#[inline]
|
||||
unsafe fn mul64_64_s(x: __m256i, y: __m256i) -> (__m256i, __m256i) {
|
||||
let x_hi = _mm256_srli_epi64(x, 32);
|
||||
let y_hi = _mm256_srli_epi64(y, 32);
|
||||
unsafe fn mul64_64(x: __m256i, y: __m256i) -> (__m256i, __m256i) {
|
||||
// We want to move the high 32 bits to the low position. The multiplication instruction ignores
|
||||
// the high 32 bits, so it's ok to just duplicate it into the low position. This duplication can
|
||||
// be done on port 5; bitshifts run on ports 0 and 1, competing with multiplication.
|
||||
// This instruction is only provided for 32-bit floats, not integers. Idk why Intel makes the
|
||||
// distinction; the casts are free and it guarantees that the exact bit pattern is preserved.
|
||||
// Using a swizzle instruction of the wrong domain (float vs int) does not increase latency
|
||||
// since Haswell.
|
||||
let x_hi = _mm256_castps_si256(_mm256_movehdup_ps(_mm256_castsi256_ps(x)));
|
||||
let y_hi = _mm256_castps_si256(_mm256_movehdup_ps(_mm256_castsi256_ps(y)));
|
||||
|
||||
// All four pairwise multiplications
|
||||
let mul_ll = _mm256_mul_epu32(x, y);
|
||||
let mul_lh = _mm256_mul_epu32(x, y_hi);
|
||||
let mul_hl = _mm256_mul_epu32(x_hi, y);
|
||||
let mul_hh = _mm256_mul_epu32(x_hi, y_hi);
|
||||
|
||||
let res_lo0_s = shift(mul_ll);
|
||||
let res_lo1_s = _mm256_add_epi32(res_lo0_s, _mm256_slli_epi64(mul_lh, 32));
|
||||
let res_lo2_s = _mm256_add_epi32(res_lo1_s, _mm256_slli_epi64(mul_hl, 32));
|
||||
// Bignum addition
|
||||
// Extract high 32 bits of mul_ll and add to mul_hl. This cannot overflow.
|
||||
let mul_ll_hi = _mm256_srli_epi64::<32>(mul_ll);
|
||||
let t0 = _mm256_add_epi64(mul_hl, mul_ll_hi);
|
||||
// Extract low 32 bits of t0 and add to mul_lh. Again, this cannot overflow.
|
||||
// Also, extract high 32 bits of t0 and add to mul_hh.
|
||||
let t0_lo = _mm256_and_si256(t0, _mm256_set1_epi64x(u32::MAX.into()));
|
||||
let t0_hi = _mm256_srli_epi64::<32>(t0);
|
||||
let t1 = _mm256_add_epi64(mul_lh, t0_lo);
|
||||
let t2 = _mm256_add_epi64(mul_hh, t0_hi);
|
||||
// Lastly, extract the high 32 bits of t1 and add to t2.
|
||||
let t1_hi = _mm256_srli_epi64::<32>(t1);
|
||||
let res_hi = _mm256_add_epi64(t2, t1_hi);
|
||||
|
||||
// cmpgt returns -1 on true and 0 on false. Hence, the carry values below are set to -1 on
|
||||
// overflow and must be subtracted, not added.
|
||||
let carry0 = _mm256_cmpgt_epi64(res_lo0_s, res_lo1_s);
|
||||
let carry1 = _mm256_cmpgt_epi64(res_lo1_s, res_lo2_s);
|
||||
// Form res_lo by combining the low half of mul_ll with the low half of t1 (shifted into high
|
||||
// position).
|
||||
let t1_lo = _mm256_castps_si256(_mm256_moveldup_ps(_mm256_castsi256_ps(t1)));
|
||||
let res_lo = _mm256_blend_epi32::<0xaa>(mul_ll, t1_lo);
|
||||
|
||||
let res_hi0 = mul_hh;
|
||||
let res_hi1 = _mm256_add_epi64(res_hi0, _mm256_srli_epi64(mul_lh, 32));
|
||||
let res_hi2 = _mm256_add_epi64(res_hi1, _mm256_srli_epi64(mul_hl, 32));
|
||||
let res_hi3 = _mm256_sub_epi64(res_hi2, carry0);
|
||||
let res_hi4 = _mm256_sub_epi64(res_hi3, carry1);
|
||||
|
||||
(res_hi4, res_lo2_s)
|
||||
(res_hi, res_lo)
|
||||
}
|
||||
|
||||
/// Full 64-bit squaring. This routine is 1.2x faster than the scalar instruction.
|
||||
#[inline]
|
||||
unsafe fn square64_s(x: __m256i) -> (__m256i, __m256i) {
|
||||
let x_hi = _mm256_srli_epi64(x, 32);
|
||||
unsafe fn square64(x: __m256i) -> (__m256i, __m256i) {
|
||||
// Get high 32 bits of x. See comment in mul64_64_s.
|
||||
let x_hi = _mm256_castps_si256(_mm256_movehdup_ps(_mm256_castsi256_ps(x)));
|
||||
|
||||
// All pairwise multiplications.
|
||||
let mul_ll = _mm256_mul_epu32(x, x);
|
||||
let mul_lh = _mm256_mul_epu32(x, x_hi);
|
||||
let mul_hh = _mm256_mul_epu32(x_hi, x_hi);
|
||||
|
||||
let res_lo0_s = shift(mul_ll);
|
||||
let res_lo1_s = _mm256_add_epi32(res_lo0_s, _mm256_slli_epi64(mul_lh, 33));
|
||||
// Bignum addition, but mul_lh is shifted by 33 bits (not 32).
|
||||
let mul_ll_hi = _mm256_srli_epi64::<33>(mul_ll);
|
||||
let t0 = _mm256_add_epi64(mul_lh, mul_ll_hi);
|
||||
let t0_hi = _mm256_srli_epi64::<31>(t0);
|
||||
let res_hi = _mm256_add_epi64(mul_hh, t0_hi);
|
||||
|
||||
// cmpgt returns -1 on true and 0 on false. Hence, the carry values below are set to -1 on
|
||||
// overflow and must be subtracted, not added.
|
||||
let carry = _mm256_cmpgt_epi64(res_lo0_s, res_lo1_s);
|
||||
// Form low result by adding the mul_ll and the low 31 bits of mul_lh (shifted to the high
|
||||
// position).
|
||||
let mul_lh_lo = _mm256_slli_epi64::<33>(mul_lh);
|
||||
let res_lo = _mm256_add_epi64(mul_ll, mul_lh_lo);
|
||||
|
||||
let res_hi0 = mul_hh;
|
||||
let res_hi1 = _mm256_add_epi64(res_hi0, _mm256_srli_epi64(mul_lh, 31));
|
||||
let res_hi2 = _mm256_sub_epi64(res_hi1, carry);
|
||||
|
||||
(res_hi2, res_lo1_s)
|
||||
(res_hi, res_lo)
|
||||
}
|
||||
|
||||
/// Multiply two integers modulo FIELD_ORDER.
|
||||
#[inline]
|
||||
unsafe fn mul<F: ReducibleAVX2>(x: __m256i, y: __m256i) -> __m256i {
|
||||
shift(F::reduce128s_s(mul64_64_s(x, y)))
|
||||
unsafe fn mul<F: ReducibleAvx2>(x: __m256i, y: __m256i) -> __m256i {
|
||||
F::reduce128(mul64_64(x, y))
|
||||
}
|
||||
|
||||
/// Square an integer modulo FIELD_ORDER.
|
||||
#[inline]
|
||||
unsafe fn square<F: ReducibleAVX2>(x: __m256i) -> __m256i {
|
||||
shift(F::reduce128s_s(square64_s(x)))
|
||||
unsafe fn square<F: ReducibleAvx2>(x: __m256i) -> __m256i {
|
||||
F::reduce128(square64(x))
|
||||
}
|
||||
|
||||
#[inline]
|
||||
unsafe fn interleave0(x: __m256i, y: __m256i) -> (__m256i, __m256i) {
|
||||
unsafe fn interleave1(x: __m256i, y: __m256i) -> (__m256i, __m256i) {
|
||||
let a = _mm256_unpacklo_epi64(x, y);
|
||||
let b = _mm256_unpackhi_epi64(x, y);
|
||||
(a, b)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
unsafe fn interleave1(x: __m256i, y: __m256i) -> (__m256i, __m256i) {
|
||||
unsafe fn interleave2(x: __m256i, y: __m256i) -> (__m256i, __m256i) {
|
||||
let y_lo = _mm256_castsi256_si128(y); // This has 0 cost.
|
||||
|
||||
// 1 places y_lo in the high half of x; 0 would place it in the lower half.
|
||||
@ -2,8 +2,22 @@ use core::arch::x86_64::*;
|
||||
|
||||
use crate::field::field_types::PrimeField;
|
||||
|
||||
pub trait ReducibleAVX2: PrimeField {
|
||||
unsafe fn reduce128s_s(x_s: (__m256i, __m256i)) -> __m256i;
|
||||
pub trait ReducibleAvx2: PrimeField {
|
||||
unsafe fn reduce128(x: (__m256i, __m256i)) -> __m256i;
|
||||
}
|
||||
|
||||
const SIGN_BIT: u64 = 1 << 63;
|
||||
|
||||
#[inline]
|
||||
unsafe fn sign_bit() -> __m256i {
|
||||
_mm256_set1_epi64x(SIGN_BIT as i64)
|
||||
}
|
||||
|
||||
/// Add 2^63 with overflow. Needed to emulate unsigned comparisons (see point 3. in
|
||||
/// packed_prime_field.rs).
|
||||
#[inline]
|
||||
pub unsafe fn shift(x: __m256i) -> __m256i {
|
||||
_mm256_xor_si256(x, sign_bit())
|
||||
}
|
||||
|
||||
#[inline]
|
||||
|
||||
@ -2,19 +2,21 @@ use core::arch::x86_64::*;
|
||||
|
||||
use crate::field::goldilocks_field::GoldilocksField;
|
||||
use crate::field::packed_avx2::common::{
|
||||
add_no_canonicalize_64_64s_s, epsilon, sub_no_canonicalize_64s_64_s, ReducibleAVX2,
|
||||
add_no_canonicalize_64_64s_s, epsilon, shift, sub_no_canonicalize_64s_64_s, ReducibleAvx2,
|
||||
};
|
||||
|
||||
/// Reduce a u128 modulo FIELD_ORDER. The input is (u64, u64), pre-shifted by 2^63. The result is
|
||||
/// similarly shifted.
|
||||
impl ReducibleAVX2 for GoldilocksField {
|
||||
impl ReducibleAvx2 for GoldilocksField {
|
||||
#[inline]
|
||||
unsafe fn reduce128s_s(x_s: (__m256i, __m256i)) -> __m256i {
|
||||
let (hi0, lo0_s) = x_s;
|
||||
unsafe fn reduce128(x: (__m256i, __m256i)) -> __m256i {
|
||||
let (hi0, lo0) = x;
|
||||
let lo0_s = shift(lo0);
|
||||
let hi_hi0 = _mm256_srli_epi64(hi0, 32);
|
||||
let lo1_s = sub_no_canonicalize_64s_64_s::<GoldilocksField>(lo0_s, hi_hi0);
|
||||
let t1 = _mm256_mul_epu32(hi0, epsilon::<GoldilocksField>());
|
||||
let lo2_s = add_no_canonicalize_64_64s_s::<GoldilocksField>(t1, lo1_s);
|
||||
lo2_s
|
||||
let lo2 = shift(lo2_s);
|
||||
lo2
|
||||
}
|
||||
}
|
||||
|
||||
@ -1,21 +1,21 @@
|
||||
mod avx2_prime_field;
|
||||
mod common;
|
||||
mod goldilocks;
|
||||
mod packed_prime_field;
|
||||
|
||||
use packed_prime_field::PackedPrimeField;
|
||||
use avx2_prime_field::Avx2PrimeField;
|
||||
|
||||
use crate::field::goldilocks_field::GoldilocksField;
|
||||
|
||||
pub type PackedGoldilocksAVX2 = PackedPrimeField<GoldilocksField>;
|
||||
pub type PackedGoldilocksAvx2 = Avx2PrimeField<GoldilocksField>;
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::field::goldilocks_field::GoldilocksField;
|
||||
use crate::field::packed_avx2::common::ReducibleAVX2;
|
||||
use crate::field::packed_avx2::packed_prime_field::PackedPrimeField;
|
||||
use crate::field::packed_avx2::avx2_prime_field::Avx2PrimeField;
|
||||
use crate::field::packed_avx2::common::ReducibleAvx2;
|
||||
use crate::field::packed_field::PackedField;
|
||||
|
||||
fn test_vals_a<F: ReducibleAVX2>() -> [F; 4] {
|
||||
fn test_vals_a<F: ReducibleAvx2>() -> [F; 4] {
|
||||
[
|
||||
F::from_noncanonical_u64(14479013849828404771),
|
||||
F::from_noncanonical_u64(9087029921428221768),
|
||||
@ -23,7 +23,7 @@ mod tests {
|
||||
F::from_noncanonical_u64(5646033492608483824),
|
||||
]
|
||||
}
|
||||
fn test_vals_b<F: ReducibleAVX2>() -> [F; 4] {
|
||||
fn test_vals_b<F: ReducibleAvx2>() -> [F; 4] {
|
||||
[
|
||||
F::from_noncanonical_u64(17891926589593242302),
|
||||
F::from_noncanonical_u64(11009798273260028228),
|
||||
@ -32,17 +32,17 @@ mod tests {
|
||||
]
|
||||
}
|
||||
|
||||
fn test_add<F: ReducibleAVX2>()
|
||||
fn test_add<F: ReducibleAvx2>()
|
||||
where
|
||||
[(); PackedPrimeField::<F>::WIDTH]: ,
|
||||
[(); Avx2PrimeField::<F>::WIDTH]:,
|
||||
{
|
||||
let a_arr = test_vals_a::<F>();
|
||||
let b_arr = test_vals_b::<F>();
|
||||
|
||||
let packed_a = PackedPrimeField::<F>::from_arr(a_arr);
|
||||
let packed_b = PackedPrimeField::<F>::from_arr(b_arr);
|
||||
let packed_a = Avx2PrimeField::<F>::from_arr(a_arr);
|
||||
let packed_b = Avx2PrimeField::<F>::from_arr(b_arr);
|
||||
let packed_res = packed_a + packed_b;
|
||||
let arr_res = packed_res.to_arr();
|
||||
let arr_res = packed_res.as_arr();
|
||||
|
||||
let expected = a_arr.iter().zip(b_arr).map(|(&a, b)| a + b);
|
||||
for (exp, res) in expected.zip(arr_res) {
|
||||
@ -50,17 +50,17 @@ mod tests {
|
||||
}
|
||||
}
|
||||
|
||||
fn test_mul<F: ReducibleAVX2>()
|
||||
fn test_mul<F: ReducibleAvx2>()
|
||||
where
|
||||
[(); PackedPrimeField::<F>::WIDTH]: ,
|
||||
[(); Avx2PrimeField::<F>::WIDTH]:,
|
||||
{
|
||||
let a_arr = test_vals_a::<F>();
|
||||
let b_arr = test_vals_b::<F>();
|
||||
|
||||
let packed_a = PackedPrimeField::<F>::from_arr(a_arr);
|
||||
let packed_b = PackedPrimeField::<F>::from_arr(b_arr);
|
||||
let packed_a = Avx2PrimeField::<F>::from_arr(a_arr);
|
||||
let packed_b = Avx2PrimeField::<F>::from_arr(b_arr);
|
||||
let packed_res = packed_a * packed_b;
|
||||
let arr_res = packed_res.to_arr();
|
||||
let arr_res = packed_res.as_arr();
|
||||
|
||||
let expected = a_arr.iter().zip(b_arr).map(|(&a, b)| a * b);
|
||||
for (exp, res) in expected.zip(arr_res) {
|
||||
@ -68,15 +68,15 @@ mod tests {
|
||||
}
|
||||
}
|
||||
|
||||
fn test_square<F: ReducibleAVX2>()
|
||||
fn test_square<F: ReducibleAvx2>()
|
||||
where
|
||||
[(); PackedPrimeField::<F>::WIDTH]: ,
|
||||
[(); Avx2PrimeField::<F>::WIDTH]:,
|
||||
{
|
||||
let a_arr = test_vals_a::<F>();
|
||||
|
||||
let packed_a = PackedPrimeField::<F>::from_arr(a_arr);
|
||||
let packed_a = Avx2PrimeField::<F>::from_arr(a_arr);
|
||||
let packed_res = packed_a.square();
|
||||
let arr_res = packed_res.to_arr();
|
||||
let arr_res = packed_res.as_arr();
|
||||
|
||||
let expected = a_arr.iter().map(|&a| a.square());
|
||||
for (exp, res) in expected.zip(arr_res) {
|
||||
@ -84,15 +84,15 @@ mod tests {
|
||||
}
|
||||
}
|
||||
|
||||
fn test_neg<F: ReducibleAVX2>()
|
||||
fn test_neg<F: ReducibleAvx2>()
|
||||
where
|
||||
[(); PackedPrimeField::<F>::WIDTH]: ,
|
||||
[(); Avx2PrimeField::<F>::WIDTH]:,
|
||||
{
|
||||
let a_arr = test_vals_a::<F>();
|
||||
|
||||
let packed_a = PackedPrimeField::<F>::from_arr(a_arr);
|
||||
let packed_a = Avx2PrimeField::<F>::from_arr(a_arr);
|
||||
let packed_res = -packed_a;
|
||||
let arr_res = packed_res.to_arr();
|
||||
let arr_res = packed_res.as_arr();
|
||||
|
||||
let expected = a_arr.iter().map(|&a| -a);
|
||||
for (exp, res) in expected.zip(arr_res) {
|
||||
@ -100,17 +100,17 @@ mod tests {
|
||||
}
|
||||
}
|
||||
|
||||
fn test_sub<F: ReducibleAVX2>()
|
||||
fn test_sub<F: ReducibleAvx2>()
|
||||
where
|
||||
[(); PackedPrimeField::<F>::WIDTH]: ,
|
||||
[(); Avx2PrimeField::<F>::WIDTH]:,
|
||||
{
|
||||
let a_arr = test_vals_a::<F>();
|
||||
let b_arr = test_vals_b::<F>();
|
||||
|
||||
let packed_a = PackedPrimeField::<F>::from_arr(a_arr);
|
||||
let packed_b = PackedPrimeField::<F>::from_arr(b_arr);
|
||||
let packed_a = Avx2PrimeField::<F>::from_arr(a_arr);
|
||||
let packed_b = Avx2PrimeField::<F>::from_arr(b_arr);
|
||||
let packed_res = packed_a - packed_b;
|
||||
let arr_res = packed_res.to_arr();
|
||||
let arr_res = packed_res.as_arr();
|
||||
|
||||
let expected = a_arr.iter().zip(b_arr).map(|(&a, b)| a - b);
|
||||
for (exp, res) in expected.zip(arr_res) {
|
||||
@ -118,33 +118,39 @@ mod tests {
|
||||
}
|
||||
}
|
||||
|
||||
fn test_interleave_is_involution<F: ReducibleAVX2>()
|
||||
fn test_interleave_is_involution<F: ReducibleAvx2>()
|
||||
where
|
||||
[(); PackedPrimeField::<F>::WIDTH]: ,
|
||||
[(); Avx2PrimeField::<F>::WIDTH]:,
|
||||
{
|
||||
let a_arr = test_vals_a::<F>();
|
||||
let b_arr = test_vals_b::<F>();
|
||||
|
||||
let packed_a = PackedPrimeField::<F>::from_arr(a_arr);
|
||||
let packed_b = PackedPrimeField::<F>::from_arr(b_arr);
|
||||
let packed_a = Avx2PrimeField::<F>::from_arr(a_arr);
|
||||
let packed_b = Avx2PrimeField::<F>::from_arr(b_arr);
|
||||
{
|
||||
// Interleave, then deinterleave.
|
||||
let (x, y) = packed_a.interleave(packed_b, 0);
|
||||
let (res_a, res_b) = x.interleave(y, 0);
|
||||
assert_eq!(res_a.to_arr(), a_arr);
|
||||
assert_eq!(res_b.to_arr(), b_arr);
|
||||
}
|
||||
{
|
||||
let (x, y) = packed_a.interleave(packed_b, 1);
|
||||
let (res_a, res_b) = x.interleave(y, 1);
|
||||
assert_eq!(res_a.to_arr(), a_arr);
|
||||
assert_eq!(res_b.to_arr(), b_arr);
|
||||
assert_eq!(res_a.as_arr(), a_arr);
|
||||
assert_eq!(res_b.as_arr(), b_arr);
|
||||
}
|
||||
{
|
||||
let (x, y) = packed_a.interleave(packed_b, 2);
|
||||
let (res_a, res_b) = x.interleave(y, 2);
|
||||
assert_eq!(res_a.as_arr(), a_arr);
|
||||
assert_eq!(res_b.as_arr(), b_arr);
|
||||
}
|
||||
{
|
||||
let (x, y) = packed_a.interleave(packed_b, 4);
|
||||
let (res_a, res_b) = x.interleave(y, 4);
|
||||
assert_eq!(res_a.as_arr(), a_arr);
|
||||
assert_eq!(res_b.as_arr(), b_arr);
|
||||
}
|
||||
}
|
||||
|
||||
fn test_interleave<F: ReducibleAVX2>()
|
||||
fn test_interleave<F: ReducibleAvx2>()
|
||||
where
|
||||
[(); PackedPrimeField::<F>::WIDTH]: ,
|
||||
[(); Avx2PrimeField::<F>::WIDTH]:,
|
||||
{
|
||||
let in_a: [F; 4] = [
|
||||
F::from_noncanonical_u64(00),
|
||||
@ -158,42 +164,47 @@ mod tests {
|
||||
F::from_noncanonical_u64(12),
|
||||
F::from_noncanonical_u64(13),
|
||||
];
|
||||
let int0_a: [F; 4] = [
|
||||
let int1_a: [F; 4] = [
|
||||
F::from_noncanonical_u64(00),
|
||||
F::from_noncanonical_u64(10),
|
||||
F::from_noncanonical_u64(02),
|
||||
F::from_noncanonical_u64(12),
|
||||
];
|
||||
let int0_b: [F; 4] = [
|
||||
let int1_b: [F; 4] = [
|
||||
F::from_noncanonical_u64(01),
|
||||
F::from_noncanonical_u64(11),
|
||||
F::from_noncanonical_u64(03),
|
||||
F::from_noncanonical_u64(13),
|
||||
];
|
||||
let int1_a: [F; 4] = [
|
||||
let int2_a: [F; 4] = [
|
||||
F::from_noncanonical_u64(00),
|
||||
F::from_noncanonical_u64(01),
|
||||
F::from_noncanonical_u64(10),
|
||||
F::from_noncanonical_u64(11),
|
||||
];
|
||||
let int1_b: [F; 4] = [
|
||||
let int2_b: [F; 4] = [
|
||||
F::from_noncanonical_u64(02),
|
||||
F::from_noncanonical_u64(03),
|
||||
F::from_noncanonical_u64(12),
|
||||
F::from_noncanonical_u64(13),
|
||||
];
|
||||
|
||||
let packed_a = PackedPrimeField::<F>::from_arr(in_a);
|
||||
let packed_b = PackedPrimeField::<F>::from_arr(in_b);
|
||||
{
|
||||
let (x0, y0) = packed_a.interleave(packed_b, 0);
|
||||
assert_eq!(x0.to_arr(), int0_a);
|
||||
assert_eq!(y0.to_arr(), int0_b);
|
||||
}
|
||||
let packed_a = Avx2PrimeField::<F>::from_arr(in_a);
|
||||
let packed_b = Avx2PrimeField::<F>::from_arr(in_b);
|
||||
{
|
||||
let (x1, y1) = packed_a.interleave(packed_b, 1);
|
||||
assert_eq!(x1.to_arr(), int1_a);
|
||||
assert_eq!(y1.to_arr(), int1_b);
|
||||
assert_eq!(x1.as_arr(), int1_a);
|
||||
assert_eq!(y1.as_arr(), int1_b);
|
||||
}
|
||||
{
|
||||
let (x2, y2) = packed_a.interleave(packed_b, 2);
|
||||
assert_eq!(x2.as_arr(), int2_a);
|
||||
assert_eq!(y2.as_arr(), int2_b);
|
||||
}
|
||||
{
|
||||
let (x4, y4) = packed_a.interleave(packed_b, 4);
|
||||
assert_eq!(x4.as_arr(), in_a);
|
||||
assert_eq!(y4.as_arr(), in_b);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -1,77 +1,81 @@
|
||||
use std::fmt;
|
||||
use std::fmt::{Debug, Formatter};
|
||||
use std::fmt::Debug;
|
||||
use std::iter::{Product, Sum};
|
||||
use std::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign};
|
||||
use std::ops::{Add, AddAssign, Div, Mul, MulAssign, Neg, Sub, SubAssign};
|
||||
use std::slice;
|
||||
|
||||
use crate::field::field_types::Field;
|
||||
|
||||
pub trait PackedField:
|
||||
/// # Safety
|
||||
/// - WIDTH is assumed to be a power of 2.
|
||||
/// - If P implements PackedField then P must be castable to/from [P::Scalar; P::WIDTH] without UB.
|
||||
pub unsafe trait PackedField:
|
||||
'static
|
||||
+ Add<Self, Output = Self>
|
||||
+ Add<Self::FieldType, Output = Self>
|
||||
+ Add<Self::Scalar, Output = Self>
|
||||
+ AddAssign<Self>
|
||||
+ AddAssign<Self::FieldType>
|
||||
+ AddAssign<Self::Scalar>
|
||||
+ Copy
|
||||
+ Debug
|
||||
+ Default
|
||||
// TODO: Implementing Div sounds like a pain so it's a worry for later.
|
||||
+ From<Self::Scalar>
|
||||
// TODO: Implement packed / packed division
|
||||
+ Div<Self::Scalar, Output = Self>
|
||||
+ Mul<Self, Output = Self>
|
||||
+ Mul<Self::FieldType, Output = Self>
|
||||
+ Mul<Self::Scalar, Output = Self>
|
||||
+ MulAssign<Self>
|
||||
+ MulAssign<Self::FieldType>
|
||||
+ MulAssign<Self::Scalar>
|
||||
+ Neg<Output = Self>
|
||||
+ Product
|
||||
+ Send
|
||||
+ Sub<Self, Output = Self>
|
||||
+ Sub<Self::FieldType, Output = Self>
|
||||
+ Sub<Self::Scalar, Output = Self>
|
||||
+ SubAssign<Self>
|
||||
+ SubAssign<Self::FieldType>
|
||||
+ SubAssign<Self::Scalar>
|
||||
+ Sum
|
||||
+ Sync
|
||||
where
|
||||
Self::Scalar: Add<Self, Output = Self>,
|
||||
Self::Scalar: Mul<Self, Output = Self>,
|
||||
Self::Scalar: Sub<Self, Output = Self>,
|
||||
{
|
||||
type FieldType: Field;
|
||||
type Scalar: Field;
|
||||
|
||||
const LOG2_WIDTH: usize;
|
||||
const WIDTH: usize = 1 << Self::LOG2_WIDTH;
|
||||
const WIDTH: usize;
|
||||
const ZERO: Self;
|
||||
const ONE: Self;
|
||||
|
||||
fn square(&self) -> Self {
|
||||
*self * *self
|
||||
}
|
||||
|
||||
fn zero() -> Self {
|
||||
Self::broadcast(Self::FieldType::ZERO)
|
||||
}
|
||||
fn one() -> Self {
|
||||
Self::broadcast(Self::FieldType::ONE)
|
||||
}
|
||||
fn from_arr(arr: [Self::Scalar; Self::WIDTH]) -> Self;
|
||||
fn as_arr(&self) -> [Self::Scalar; Self::WIDTH];
|
||||
|
||||
fn broadcast(x: Self::FieldType) -> Self;
|
||||
fn from_slice(slice: &[Self::Scalar]) -> &Self;
|
||||
fn from_slice_mut(slice: &mut [Self::Scalar]) -> &mut Self;
|
||||
fn as_slice(&self) -> &[Self::Scalar];
|
||||
fn as_slice_mut(&mut self) -> &mut [Self::Scalar];
|
||||
|
||||
fn from_arr(arr: [Self::FieldType; Self::WIDTH]) -> Self;
|
||||
fn to_arr(&self) -> [Self::FieldType; Self::WIDTH];
|
||||
|
||||
fn from_slice(slice: &[Self::FieldType]) -> Self;
|
||||
fn to_vec(&self) -> Vec<Self::FieldType>;
|
||||
|
||||
/// Take interpret two vectors as chunks of (1 << r) elements. Unpack and interleave those
|
||||
/// Take interpret two vectors as chunks of block_len elements. Unpack and interleave those
|
||||
/// chunks. This is best seen with an example. If we have:
|
||||
/// A = [x0, y0, x1, y1],
|
||||
/// B = [x2, y2, x3, y3],
|
||||
/// then
|
||||
/// interleave(A, B, 0) = ([x0, x2, x1, x3], [y0, y2, y1, y3]).
|
||||
/// interleave(A, B, 1) = ([x0, x2, x1, x3], [y0, y2, y1, y3]).
|
||||
/// Pairs that were adjacent in the input are at corresponding positions in the output.
|
||||
/// r lets us set the size of chunks we're interleaving. If we set r = 1, then for
|
||||
/// r lets us set the size of chunks we're interleaving. If we set block_len = 2, then for
|
||||
/// A = [x0, x1, y0, y1],
|
||||
/// B = [x2, x3, y2, y3],
|
||||
/// we obtain
|
||||
/// interleave(A, B, r) = ([x0, x1, x2, x3], [y0, y1, y2, y3]).
|
||||
/// interleave(A, B, block_len) = ([x0, x1, x2, x3], [y0, y1, y2, y3]).
|
||||
/// We can also think about this as stacking the vectors, dividing them into 2x2 matrices, and
|
||||
/// transposing those matrices.
|
||||
/// When r = LOG2_WIDTH, this operation is a no-op. Values of r > LOG2_WIDTH are not
|
||||
/// permitted.
|
||||
fn interleave(&self, other: Self, r: usize) -> (Self, Self);
|
||||
/// When block_len = WIDTH, this operation is a no-op. block_len must divide WIDTH. Since
|
||||
/// WIDTH is specified to be a power of 2, block_len must also be a power of 2. It cannot be 0
|
||||
/// and it cannot be > WIDTH.
|
||||
fn interleave(&self, other: Self, block_len: usize) -> (Self, Self);
|
||||
|
||||
fn pack_slice(buf: &[Self::FieldType]) -> &[Self] {
|
||||
fn pack_slice(buf: &[Self::Scalar]) -> &[Self] {
|
||||
assert!(
|
||||
buf.len() % Self::WIDTH == 0,
|
||||
"Slice length (got {}) must be a multiple of packed field width ({}).",
|
||||
@ -82,7 +86,7 @@ pub trait PackedField:
|
||||
let n = buf.len() / Self::WIDTH;
|
||||
unsafe { std::slice::from_raw_parts(buf_ptr, n) }
|
||||
}
|
||||
fn pack_slice_mut(buf: &mut [Self::FieldType]) -> &mut [Self] {
|
||||
fn pack_slice_mut(buf: &mut [Self::Scalar]) -> &mut [Self] {
|
||||
assert!(
|
||||
buf.len() % Self::WIDTH == 0,
|
||||
"Slice length (got {}) must be a multiple of packed field width ({}).",
|
||||
@ -95,143 +99,41 @@ pub trait PackedField:
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone)]
|
||||
#[repr(transparent)]
|
||||
pub struct Singleton<F: Field>(pub F);
|
||||
unsafe impl<F: Field> PackedField for F {
|
||||
type Scalar = Self;
|
||||
|
||||
impl<F: Field> Add<Self> for Singleton<F> {
|
||||
type Output = Self;
|
||||
fn add(self, rhs: Self) -> Self {
|
||||
Self(self.0 + rhs.0)
|
||||
}
|
||||
}
|
||||
impl<F: Field> Add<F> for Singleton<F> {
|
||||
type Output = Self;
|
||||
fn add(self, rhs: F) -> Self {
|
||||
self + Self::broadcast(rhs)
|
||||
}
|
||||
}
|
||||
impl<F: Field> AddAssign<Self> for Singleton<F> {
|
||||
fn add_assign(&mut self, rhs: Self) {
|
||||
*self = *self + rhs;
|
||||
}
|
||||
}
|
||||
impl<F: Field> AddAssign<F> for Singleton<F> {
|
||||
fn add_assign(&mut self, rhs: F) {
|
||||
*self = *self + rhs;
|
||||
}
|
||||
}
|
||||
|
||||
impl<F: Field> Debug for Singleton<F> {
|
||||
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
|
||||
write!(f, "({:?})", self.0)
|
||||
}
|
||||
}
|
||||
|
||||
impl<F: Field> Default for Singleton<F> {
|
||||
fn default() -> Self {
|
||||
Self::zero()
|
||||
}
|
||||
}
|
||||
|
||||
impl<F: Field> Mul<Self> for Singleton<F> {
|
||||
type Output = Self;
|
||||
fn mul(self, rhs: Self) -> Self {
|
||||
Self(self.0 * rhs.0)
|
||||
}
|
||||
}
|
||||
impl<F: Field> Mul<F> for Singleton<F> {
|
||||
type Output = Self;
|
||||
fn mul(self, rhs: F) -> Self {
|
||||
self * Self::broadcast(rhs)
|
||||
}
|
||||
}
|
||||
impl<F: Field> MulAssign<Self> for Singleton<F> {
|
||||
fn mul_assign(&mut self, rhs: Self) {
|
||||
*self = *self * rhs;
|
||||
}
|
||||
}
|
||||
impl<F: Field> MulAssign<F> for Singleton<F> {
|
||||
fn mul_assign(&mut self, rhs: F) {
|
||||
*self = *self * rhs;
|
||||
}
|
||||
}
|
||||
|
||||
impl<F: Field> Neg for Singleton<F> {
|
||||
type Output = Self;
|
||||
fn neg(self) -> Self {
|
||||
Self(-self.0)
|
||||
}
|
||||
}
|
||||
|
||||
impl<F: Field> Product for Singleton<F> {
|
||||
fn product<I: Iterator<Item = Self>>(iter: I) -> Self {
|
||||
Self(iter.map(|x| x.0).product())
|
||||
}
|
||||
}
|
||||
|
||||
impl<F: Field> PackedField for Singleton<F> {
|
||||
const LOG2_WIDTH: usize = 0;
|
||||
type FieldType = F;
|
||||
|
||||
fn broadcast(x: F) -> Self {
|
||||
Self(x)
|
||||
}
|
||||
|
||||
fn from_arr(arr: [Self::FieldType; Self::WIDTH]) -> Self {
|
||||
Self(arr[0])
|
||||
}
|
||||
|
||||
fn to_arr(&self) -> [Self::FieldType; Self::WIDTH] {
|
||||
[self.0]
|
||||
}
|
||||
|
||||
fn from_slice(slice: &[Self::FieldType]) -> Self {
|
||||
assert!(slice.len() == 1);
|
||||
Self(slice[0])
|
||||
}
|
||||
|
||||
fn to_vec(&self) -> Vec<Self::FieldType> {
|
||||
vec![self.0]
|
||||
}
|
||||
|
||||
fn interleave(&self, other: Self, r: usize) -> (Self, Self) {
|
||||
match r {
|
||||
0 => (*self, other), // This is a no-op whenever r == LOG2_WIDTH.
|
||||
_ => panic!("r cannot be more than LOG2_WIDTH"),
|
||||
}
|
||||
}
|
||||
const WIDTH: usize = 1;
|
||||
const ZERO: Self = <F as Field>::ZERO;
|
||||
const ONE: Self = <F as Field>::ONE;
|
||||
|
||||
fn square(&self) -> Self {
|
||||
Self(self.0.square())
|
||||
<Self as Field>::square(self)
|
||||
}
|
||||
}
|
||||
|
||||
impl<F: Field> Sub<Self> for Singleton<F> {
|
||||
type Output = Self;
|
||||
fn sub(self, rhs: Self) -> Self {
|
||||
Self(self.0 - rhs.0)
|
||||
fn from_arr(arr: [Self::Scalar; Self::WIDTH]) -> Self {
|
||||
arr[0]
|
||||
}
|
||||
}
|
||||
impl<F: Field> Sub<F> for Singleton<F> {
|
||||
type Output = Self;
|
||||
fn sub(self, rhs: F) -> Self {
|
||||
self - Self::broadcast(rhs)
|
||||
fn as_arr(&self) -> [Self::Scalar; Self::WIDTH] {
|
||||
[*self]
|
||||
}
|
||||
}
|
||||
impl<F: Field> SubAssign<Self> for Singleton<F> {
|
||||
fn sub_assign(&mut self, rhs: Self) {
|
||||
*self = *self - rhs;
|
||||
}
|
||||
}
|
||||
impl<F: Field> SubAssign<F> for Singleton<F> {
|
||||
fn sub_assign(&mut self, rhs: F) {
|
||||
*self = *self - rhs;
|
||||
}
|
||||
}
|
||||
|
||||
impl<F: Field> Sum for Singleton<F> {
|
||||
fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
|
||||
Self(iter.map(|x| x.0).sum())
|
||||
fn from_slice(slice: &[Self::Scalar]) -> &Self {
|
||||
&slice[0]
|
||||
}
|
||||
fn from_slice_mut(slice: &mut [Self::Scalar]) -> &mut Self {
|
||||
&mut slice[0]
|
||||
}
|
||||
fn as_slice(&self) -> &[Self::Scalar] {
|
||||
slice::from_ref(self)
|
||||
}
|
||||
fn as_slice_mut(&mut self) -> &mut [Self::Scalar] {
|
||||
slice::from_mut(self)
|
||||
}
|
||||
|
||||
fn interleave(&self, other: Self, block_len: usize) -> (Self, Self) {
|
||||
match block_len {
|
||||
1 => (*self, other),
|
||||
_ => panic!("unsupported block length"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -24,7 +24,7 @@ where
|
||||
ExpectedOp: Fn(u64) -> u64,
|
||||
{
|
||||
let inputs = test_inputs(F::ORDER);
|
||||
let expected: Vec<_> = inputs.iter().map(|x| expected_op(x.clone())).collect();
|
||||
let expected: Vec<_> = inputs.iter().map(|&x| expected_op(x)).collect();
|
||||
let output: Vec<_> = inputs
|
||||
.iter()
|
||||
.cloned()
|
||||
@ -144,7 +144,7 @@ macro_rules! test_prime_field_arithmetic {
|
||||
fn inverse_2exp() {
|
||||
type F = $field;
|
||||
|
||||
let v = <F as Field>::PrimeField::TWO_ADICITY;
|
||||
let v = <F as Field>::TWO_ADICITY;
|
||||
|
||||
for e in [0, 1, 2, 3, 4, v - 2, v - 1, v, v + 1, v + 2, 123 * v] {
|
||||
let x = F::TWO.exp_u64(e as u64);
|
||||
|
||||
@ -1,4 +1,3 @@
|
||||
use std::convert::TryInto;
|
||||
use std::fmt;
|
||||
use std::fmt::{Debug, Display, Formatter};
|
||||
use std::hash::{Hash, Hasher};
|
||||
@ -12,7 +11,6 @@ use rand::Rng;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::field::field_types::Field;
|
||||
use crate::field::goldilocks_field::GoldilocksField;
|
||||
|
||||
/// The base field of the secp256k1 elliptic curve.
|
||||
///
|
||||
@ -36,8 +34,80 @@ fn biguint_from_array(arr: [u64; 4]) -> BigUint {
|
||||
])
|
||||
}
|
||||
|
||||
impl Secp256K1Base {
|
||||
fn to_canonical_biguint(&self) -> BigUint {
|
||||
impl Default for Secp256K1Base {
|
||||
fn default() -> Self {
|
||||
Self::ZERO
|
||||
}
|
||||
}
|
||||
|
||||
impl PartialEq for Secp256K1Base {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
self.to_biguint() == other.to_biguint()
|
||||
}
|
||||
}
|
||||
|
||||
impl Eq for Secp256K1Base {}
|
||||
|
||||
impl Hash for Secp256K1Base {
|
||||
fn hash<H: Hasher>(&self, state: &mut H) {
|
||||
self.to_biguint().hash(state)
|
||||
}
|
||||
}
|
||||
|
||||
impl Display for Secp256K1Base {
|
||||
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
|
||||
Display::fmt(&self.to_biguint(), f)
|
||||
}
|
||||
}
|
||||
|
||||
impl Debug for Secp256K1Base {
|
||||
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
|
||||
Debug::fmt(&self.to_biguint(), f)
|
||||
}
|
||||
}
|
||||
|
||||
impl Field for Secp256K1Base {
|
||||
const ZERO: Self = Self([0; 4]);
|
||||
const ONE: Self = Self([1, 0, 0, 0]);
|
||||
const TWO: Self = Self([2, 0, 0, 0]);
|
||||
const NEG_ONE: Self = Self([
|
||||
0xFFFFFFFEFFFFFC2E,
|
||||
0xFFFFFFFFFFFFFFFF,
|
||||
0xFFFFFFFFFFFFFFFF,
|
||||
0xFFFFFFFFFFFFFFFF,
|
||||
]);
|
||||
|
||||
const TWO_ADICITY: usize = 1;
|
||||
const CHARACTERISTIC_TWO_ADICITY: usize = Self::TWO_ADICITY;
|
||||
|
||||
// Sage: `g = GF(p).multiplicative_generator()`
|
||||
const MULTIPLICATIVE_GROUP_GENERATOR: Self = Self([5, 0, 0, 0]);
|
||||
|
||||
// Sage: `g_2 = g^((p - 1) / 2)`
|
||||
const POWER_OF_TWO_GENERATOR: Self = Self::NEG_ONE;
|
||||
|
||||
const BITS: usize = 256;
|
||||
|
||||
fn order() -> BigUint {
|
||||
BigUint::from_slice(&[
|
||||
0xFFFFFC2F, 0xFFFFFFFE, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF,
|
||||
0xFFFFFFFF,
|
||||
])
|
||||
}
|
||||
fn characteristic() -> BigUint {
|
||||
Self::order()
|
||||
}
|
||||
|
||||
fn try_inverse(&self) -> Option<Self> {
|
||||
if self.is_zero() {
|
||||
return None;
|
||||
}
|
||||
|
||||
// Fermat's Little Theorem
|
||||
Some(self.exp_biguint(&(Self::order() - BigUint::one() - BigUint::one())))
|
||||
}
|
||||
|
||||
fn to_biguint(&self) -> BigUint {
|
||||
let mut result = biguint_from_array(self.0);
|
||||
if result >= Self::order() {
|
||||
result -= Self::order();
|
||||
@ -55,79 +125,6 @@ impl Secp256K1Base {
|
||||
.expect("error converting to u64 array"),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for Secp256K1Base {
|
||||
fn default() -> Self {
|
||||
Self::ZERO
|
||||
}
|
||||
}
|
||||
|
||||
impl PartialEq for Secp256K1Base {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
self.to_canonical_biguint() == other.to_canonical_biguint()
|
||||
}
|
||||
}
|
||||
|
||||
impl Eq for Secp256K1Base {}
|
||||
|
||||
impl Hash for Secp256K1Base {
|
||||
fn hash<H: Hasher>(&self, state: &mut H) {
|
||||
self.to_canonical_biguint().hash(state)
|
||||
}
|
||||
}
|
||||
|
||||
impl Display for Secp256K1Base {
|
||||
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
|
||||
Display::fmt(&self.to_canonical_biguint(), f)
|
||||
}
|
||||
}
|
||||
|
||||
impl Debug for Secp256K1Base {
|
||||
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
|
||||
Debug::fmt(&self.to_canonical_biguint(), f)
|
||||
}
|
||||
}
|
||||
|
||||
impl Field for Secp256K1Base {
|
||||
// TODO: fix
|
||||
type PrimeField = GoldilocksField;
|
||||
|
||||
const ZERO: Self = Self([0; 4]);
|
||||
const ONE: Self = Self([1, 0, 0, 0]);
|
||||
const TWO: Self = Self([2, 0, 0, 0]);
|
||||
const NEG_ONE: Self = Self([
|
||||
0xFFFFFFFEFFFFFC2E,
|
||||
0xFFFFFFFFFFFFFFFF,
|
||||
0xFFFFFFFFFFFFFFFF,
|
||||
0xFFFFFFFFFFFFFFFF,
|
||||
]);
|
||||
|
||||
// TODO: fix
|
||||
const CHARACTERISTIC: u64 = 0;
|
||||
const TWO_ADICITY: usize = 1;
|
||||
|
||||
// Sage: `g = GF(p).multiplicative_generator()`
|
||||
const MULTIPLICATIVE_GROUP_GENERATOR: Self = Self([5, 0, 0, 0]);
|
||||
|
||||
// Sage: `g_2 = g^((p - 1) / 2)`
|
||||
const POWER_OF_TWO_GENERATOR: Self = Self::NEG_ONE;
|
||||
|
||||
fn order() -> BigUint {
|
||||
BigUint::from_slice(&[
|
||||
0xFFFFFC2F, 0xFFFFFFFE, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF,
|
||||
0xFFFFFFFF,
|
||||
])
|
||||
}
|
||||
|
||||
fn try_inverse(&self) -> Option<Self> {
|
||||
if self.is_zero() {
|
||||
return None;
|
||||
}
|
||||
|
||||
// Fermat's Little Theorem
|
||||
Some(self.exp_biguint(&(Self::order() - BigUint::one() - BigUint::one())))
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn from_canonical_u64(n: u64) -> Self {
|
||||
@ -157,7 +154,7 @@ impl Neg for Secp256K1Base {
|
||||
if self.is_zero() {
|
||||
Self::ZERO
|
||||
} else {
|
||||
Self::from_biguint(Self::order() - self.to_canonical_biguint())
|
||||
Self::from_biguint(Self::order() - self.to_biguint())
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -167,7 +164,7 @@ impl Add for Secp256K1Base {
|
||||
|
||||
#[inline]
|
||||
fn add(self, rhs: Self) -> Self {
|
||||
let mut result = self.to_canonical_biguint() + rhs.to_canonical_biguint();
|
||||
let mut result = self.to_biguint() + rhs.to_biguint();
|
||||
if result >= Self::order() {
|
||||
result -= Self::order();
|
||||
}
|
||||
@ -210,9 +207,7 @@ impl Mul for Secp256K1Base {
|
||||
|
||||
#[inline]
|
||||
fn mul(self, rhs: Self) -> Self {
|
||||
Self::from_biguint(
|
||||
(self.to_canonical_biguint() * rhs.to_canonical_biguint()).mod_floor(&Self::order()),
|
||||
)
|
||||
Self::from_biguint((self.to_biguint() * rhs.to_biguint()).mod_floor(&Self::order()))
|
||||
}
|
||||
}
|
||||
|
||||
@ -244,3 +239,10 @@ impl DivAssign for Secp256K1Base {
|
||||
*self = *self / rhs;
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::test_field_arithmetic;
|
||||
|
||||
test_field_arithmetic!(crate::field::secp256k1_base::Secp256K1Base);
|
||||
}
|
||||
257
src/field/secp256k1_scalar.rs
Normal file
257
src/field/secp256k1_scalar.rs
Normal file
@ -0,0 +1,257 @@
|
||||
use std::convert::TryInto;
|
||||
use std::fmt;
|
||||
use std::fmt::{Debug, Display, Formatter};
|
||||
use std::hash::{Hash, Hasher};
|
||||
use std::iter::{Product, Sum};
|
||||
use std::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssign};
|
||||
|
||||
use itertools::Itertools;
|
||||
use num::bigint::{BigUint, RandBigInt};
|
||||
use num::{Integer, One};
|
||||
use rand::Rng;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::field::field_types::Field;
|
||||
|
||||
/// The base field of the secp256k1 elliptic curve.
|
||||
///
|
||||
/// Its order is
|
||||
/// ```ignore
|
||||
/// P = 0xFFFFFFFF FFFFFFFF FFFFFFFF FFFFFFFE BAAEDCE6 AF48A03B BFD25E8C D0364141
|
||||
/// = 115792089237316195423570985008687907852837564279074904382605163141518161494337
|
||||
/// = 2**256 - 432420386565659656852420866394968145599
|
||||
/// ```
|
||||
#[derive(Copy, Clone, Serialize, Deserialize)]
|
||||
pub struct Secp256K1Scalar(pub [u64; 4]);
|
||||
|
||||
fn biguint_from_array(arr: [u64; 4]) -> BigUint {
|
||||
BigUint::from_slice(&[
|
||||
arr[0] as u32,
|
||||
(arr[0] >> 32) as u32,
|
||||
arr[1] as u32,
|
||||
(arr[1] >> 32) as u32,
|
||||
arr[2] as u32,
|
||||
(arr[2] >> 32) as u32,
|
||||
arr[3] as u32,
|
||||
(arr[3] >> 32) as u32,
|
||||
])
|
||||
}
|
||||
|
||||
impl Default for Secp256K1Scalar {
|
||||
fn default() -> Self {
|
||||
Self::ZERO
|
||||
}
|
||||
}
|
||||
|
||||
impl PartialEq for Secp256K1Scalar {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
self.to_biguint() == other.to_biguint()
|
||||
}
|
||||
}
|
||||
|
||||
impl Eq for Secp256K1Scalar {}
|
||||
|
||||
impl Hash for Secp256K1Scalar {
|
||||
fn hash<H: Hasher>(&self, state: &mut H) {
|
||||
self.to_biguint().hash(state)
|
||||
}
|
||||
}
|
||||
|
||||
impl Display for Secp256K1Scalar {
|
||||
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
|
||||
Display::fmt(&self.to_biguint(), f)
|
||||
}
|
||||
}
|
||||
|
||||
impl Debug for Secp256K1Scalar {
|
||||
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
|
||||
Debug::fmt(&self.to_biguint(), f)
|
||||
}
|
||||
}
|
||||
|
||||
impl Field for Secp256K1Scalar {
|
||||
const ZERO: Self = Self([0; 4]);
|
||||
const ONE: Self = Self([1, 0, 0, 0]);
|
||||
const TWO: Self = Self([2, 0, 0, 0]);
|
||||
const NEG_ONE: Self = Self([
|
||||
0xBFD25E8CD0364140,
|
||||
0xBAAEDCE6AF48A03B,
|
||||
0xFFFFFFFFFFFFFFFE,
|
||||
0xFFFFFFFFFFFFFFFF,
|
||||
]);
|
||||
|
||||
const TWO_ADICITY: usize = 6;
|
||||
const CHARACTERISTIC_TWO_ADICITY: usize = Self::TWO_ADICITY;
|
||||
|
||||
// Sage: `g = GF(p).multiplicative_generator()`
|
||||
const MULTIPLICATIVE_GROUP_GENERATOR: Self = Self([7, 0, 0, 0]);
|
||||
|
||||
// Sage: `g_2 = power_mod(g, (p - 1) // 2^6), p)`
|
||||
// 5480320495727936603795231718619559942670027629901634955707709633242980176626
|
||||
const POWER_OF_TWO_GENERATOR: Self = Self([
|
||||
0x992f4b5402b052f2,
|
||||
0x98BDEAB680756045,
|
||||
0xDF9879A3FBC483A8,
|
||||
0xC1DC060E7A91986,
|
||||
]);
|
||||
|
||||
const BITS: usize = 256;
|
||||
|
||||
fn order() -> BigUint {
|
||||
BigUint::from_slice(&[
|
||||
0xD0364141, 0xBFD25E8C, 0xAF48A03B, 0xBAAEDCE6, 0xFFFFFFFE, 0xFFFFFFFF, 0xFFFFFFFF,
|
||||
0xFFFFFFFF,
|
||||
])
|
||||
}
|
||||
fn characteristic() -> BigUint {
|
||||
Self::order()
|
||||
}
|
||||
|
||||
fn try_inverse(&self) -> Option<Self> {
|
||||
if self.is_zero() {
|
||||
return None;
|
||||
}
|
||||
|
||||
// Fermat's Little Theorem
|
||||
Some(self.exp_biguint(&(Self::order() - BigUint::one() - BigUint::one())))
|
||||
}
|
||||
|
||||
fn to_biguint(&self) -> BigUint {
|
||||
let mut result = biguint_from_array(self.0);
|
||||
if result >= Self::order() {
|
||||
result -= Self::order();
|
||||
}
|
||||
result
|
||||
}
|
||||
|
||||
fn from_biguint(val: BigUint) -> Self {
|
||||
Self(
|
||||
val.to_u64_digits()
|
||||
.into_iter()
|
||||
.pad_using(4, |_| 0)
|
||||
.collect::<Vec<_>>()[..]
|
||||
.try_into()
|
||||
.expect("error converting to u64 array"),
|
||||
)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn from_canonical_u64(n: u64) -> Self {
|
||||
Self([n, 0, 0, 0])
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn from_noncanonical_u128(n: u128) -> Self {
|
||||
Self([n as u64, (n >> 64) as u64, 0, 0])
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn from_noncanonical_u96(n: (u64, u32)) -> Self {
|
||||
Self([n.0, n.1 as u64, 0, 0])
|
||||
}
|
||||
|
||||
fn rand_from_rng<R: Rng>(rng: &mut R) -> Self {
|
||||
Self::from_biguint(rng.gen_biguint_below(&Self::order()))
|
||||
}
|
||||
}
|
||||
|
||||
impl Neg for Secp256K1Scalar {
|
||||
type Output = Self;
|
||||
|
||||
#[inline]
|
||||
fn neg(self) -> Self {
|
||||
if self.is_zero() {
|
||||
Self::ZERO
|
||||
} else {
|
||||
Self::from_biguint(Self::order() - self.to_biguint())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Add for Secp256K1Scalar {
|
||||
type Output = Self;
|
||||
|
||||
#[inline]
|
||||
fn add(self, rhs: Self) -> Self {
|
||||
let mut result = self.to_biguint() + rhs.to_biguint();
|
||||
if result >= Self::order() {
|
||||
result -= Self::order();
|
||||
}
|
||||
Self::from_biguint(result)
|
||||
}
|
||||
}
|
||||
|
||||
impl AddAssign for Secp256K1Scalar {
|
||||
#[inline]
|
||||
fn add_assign(&mut self, rhs: Self) {
|
||||
*self = *self + rhs;
|
||||
}
|
||||
}
|
||||
|
||||
impl Sum for Secp256K1Scalar {
|
||||
fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
|
||||
iter.fold(Self::ZERO, |acc, x| acc + x)
|
||||
}
|
||||
}
|
||||
|
||||
impl Sub for Secp256K1Scalar {
|
||||
type Output = Self;
|
||||
|
||||
#[inline]
|
||||
#[allow(clippy::suspicious_arithmetic_impl)]
|
||||
fn sub(self, rhs: Self) -> Self {
|
||||
self + -rhs
|
||||
}
|
||||
}
|
||||
|
||||
impl SubAssign for Secp256K1Scalar {
|
||||
#[inline]
|
||||
fn sub_assign(&mut self, rhs: Self) {
|
||||
*self = *self - rhs;
|
||||
}
|
||||
}
|
||||
|
||||
impl Mul for Secp256K1Scalar {
|
||||
type Output = Self;
|
||||
|
||||
#[inline]
|
||||
fn mul(self, rhs: Self) -> Self {
|
||||
Self::from_biguint((self.to_biguint() * rhs.to_biguint()).mod_floor(&Self::order()))
|
||||
}
|
||||
}
|
||||
|
||||
impl MulAssign for Secp256K1Scalar {
|
||||
#[inline]
|
||||
fn mul_assign(&mut self, rhs: Self) {
|
||||
*self = *self * rhs;
|
||||
}
|
||||
}
|
||||
|
||||
impl Product for Secp256K1Scalar {
|
||||
#[inline]
|
||||
fn product<I: Iterator<Item = Self>>(iter: I) -> Self {
|
||||
iter.reduce(|acc, x| acc * x).unwrap_or(Self::ONE)
|
||||
}
|
||||
}
|
||||
|
||||
impl Div for Secp256K1Scalar {
|
||||
type Output = Self;
|
||||
|
||||
#[allow(clippy::suspicious_arithmetic_impl)]
|
||||
fn div(self, rhs: Self) -> Self::Output {
|
||||
self * rhs.inverse()
|
||||
}
|
||||
}
|
||||
|
||||
impl DivAssign for Secp256K1Scalar {
|
||||
fn div_assign(&mut self, rhs: Self) {
|
||||
*self = *self / rhs;
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::test_field_arithmetic;
|
||||
|
||||
test_field_arithmetic!(crate::field::secp256k1_scalar::Secp256K1Scalar);
|
||||
}
|
||||
@ -11,14 +11,14 @@ use crate::plonk::circuit_data::CommonCircuitData;
|
||||
use crate::plonk::config::GenericConfig;
|
||||
use crate::plonk::plonk_common::PlonkPolynomials;
|
||||
use crate::plonk::proof::OpeningSet;
|
||||
use crate::polynomial::polynomial::{PolynomialCoeffs, PolynomialValues};
|
||||
use crate::polynomial::{PolynomialCoeffs, PolynomialValues};
|
||||
use crate::timed;
|
||||
use crate::util::reducing::ReducingFactor;
|
||||
use crate::util::timing::TimingTree;
|
||||
use crate::util::{log2_strict, reverse_bits, reverse_index_bits_in_place, transpose};
|
||||
|
||||
/// Two (~64 bit) field elements gives ~128 bit security.
|
||||
pub const SALT_SIZE: usize = 2;
|
||||
/// Four (~64 bit) field elements gives ~128 bit security.
|
||||
pub const SALT_SIZE: usize = 4;
|
||||
|
||||
/// Represents a batch FRI based commitment to a list of polynomials.
|
||||
pub struct PolynomialBatchCommitment<F: Extendable<D>, C: GenericConfig<D, F = F>, const D: usize> {
|
||||
|
||||
@ -36,8 +36,4 @@ impl FriParams {
|
||||
pub(crate) fn max_arity_bits(&self) -> Option<usize> {
|
||||
self.reduction_arity_bits.iter().copied().max()
|
||||
}
|
||||
|
||||
pub(crate) fn max_arity(&self) -> Option<usize> {
|
||||
self.max_arity_bits().map(|bits| 1 << bits)
|
||||
}
|
||||
}
|
||||
|
||||
@ -16,7 +16,7 @@ use crate::plonk::circuit_data::CommonCircuitData;
|
||||
use crate::plonk::config::{GenericConfig, Hasher};
|
||||
use crate::plonk::plonk_common::PolynomialsIndexBlinding;
|
||||
use crate::plonk::proof::{FriInferredElements, ProofChallenges};
|
||||
use crate::polynomial::polynomial::PolynomialCoeffs;
|
||||
use crate::polynomial::PolynomialCoeffs;
|
||||
|
||||
/// Evaluations and Merkle proof produced by the prover in a FRI query step.
|
||||
#[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq)]
|
||||
|
||||
@ -9,7 +9,7 @@ use crate::iop::challenger::Challenger;
|
||||
use crate::plonk::circuit_data::CommonCircuitData;
|
||||
use crate::plonk::config::{GenericConfig, Hasher};
|
||||
use crate::plonk::plonk_common::reduce_with_powers;
|
||||
use crate::polynomial::polynomial::{PolynomialCoeffs, PolynomialValues};
|
||||
use crate::polynomial::{PolynomialCoeffs, PolynomialValues};
|
||||
use crate::timed;
|
||||
use crate::util::reverse_index_bits_in_place;
|
||||
use crate::util::timing::TimingTree;
|
||||
|
||||
@ -3,13 +3,16 @@ use crate::field::extension_field::Extendable;
|
||||
use crate::field::field_types::{Field, RichField};
|
||||
use crate::fri::proof::{FriInitialTreeProofTarget, FriProofTarget, FriQueryRoundTarget};
|
||||
use crate::fri::FriConfig;
|
||||
use crate::gadgets::interpolation::InterpolationGate;
|
||||
use crate::gates::gate::Gate;
|
||||
use crate::gates::interpolation::InterpolationGate;
|
||||
use crate::gates::interpolation::HighDegreeInterpolationGate;
|
||||
use crate::gates::low_degree_interpolation::LowDegreeInterpolationGate;
|
||||
use crate::gates::random_access::RandomAccessGate;
|
||||
use crate::hash::hash_types::MerkleCapTarget;
|
||||
use crate::iop::challenger::RecursiveChallenger;
|
||||
use crate::iop::target::{BoolTarget, Target};
|
||||
use crate::plonk::circuit_builder::CircuitBuilder;
|
||||
use crate::plonk::circuit_data::{CircuitConfig, CommonCircuitData};
|
||||
use crate::plonk::circuit_data::CommonCircuitData;
|
||||
use crate::plonk::config::{AlgebraicConfig, AlgebraicHasher, GenericConfig};
|
||||
use crate::plonk::plonk_common::PlonkPolynomials;
|
||||
@ -28,6 +31,7 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
|
||||
arity_bits: usize,
|
||||
evals: &[ExtensionTarget<D>],
|
||||
beta: ExtensionTarget<D>,
|
||||
common_data: &CommonCircuitData<F, D>,
|
||||
) -> ExtensionTarget<D> {
|
||||
let arity = 1 << arity_bits;
|
||||
debug_assert_eq!(evals.len(), arity);
|
||||
@ -44,37 +48,62 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
|
||||
let coset_start = self.mul(start, x);
|
||||
|
||||
// The answer is gotten by interpolating {(x*g^i, P(x*g^i))} and evaluating at beta.
|
||||
self.interpolate_coset(arity_bits, coset_start, &evals, beta)
|
||||
// `HighDegreeInterpolationGate` has degree `arity`, so we use the low-degree gate if
|
||||
// the arity is too large.
|
||||
if arity > common_data.quotient_degree_factor {
|
||||
self.interpolate_coset::<LowDegreeInterpolationGate<F, D>>(
|
||||
arity_bits,
|
||||
coset_start,
|
||||
&evals,
|
||||
beta,
|
||||
)
|
||||
} else {
|
||||
self.interpolate_coset::<HighDegreeInterpolationGate<F, D>>(
|
||||
arity_bits,
|
||||
coset_start,
|
||||
&evals,
|
||||
beta,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
/// Make sure we have enough wires and routed wires to do the FRI checks efficiently. This check
|
||||
/// isn't required -- without it we'd get errors elsewhere in the stack -- but just gives more
|
||||
/// helpful errors.
|
||||
fn check_recursion_config(&self, max_fri_arity: usize) {
|
||||
fn check_recursion_config(
|
||||
&self,
|
||||
max_fri_arity_bits: usize,
|
||||
common_data: &CommonCircuitData<F, D>,
|
||||
) {
|
||||
let random_access = RandomAccessGate::<F, D>::new_from_config(
|
||||
&self.config,
|
||||
max_fri_arity.max(1 << self.config.cap_height),
|
||||
max_fri_arity_bits.max(self.config.cap_height),
|
||||
);
|
||||
let interpolation_gate = InterpolationGate::<F, D>::new(log2_strict(max_fri_arity));
|
||||
let (interpolation_wires, interpolation_routed_wires) =
|
||||
if 1 << max_fri_arity_bits > common_data.quotient_degree_factor {
|
||||
let gate = LowDegreeInterpolationGate::<F, D>::new(max_fri_arity_bits);
|
||||
(gate.num_wires(), gate.num_routed_wires())
|
||||
} else {
|
||||
let gate = HighDegreeInterpolationGate::<F, D>::new(max_fri_arity_bits);
|
||||
(gate.num_wires(), gate.num_routed_wires())
|
||||
};
|
||||
|
||||
let min_wires = random_access
|
||||
.num_wires()
|
||||
.max(interpolation_gate.num_wires());
|
||||
let min_wires = random_access.num_wires().max(interpolation_wires);
|
||||
let min_routed_wires = random_access
|
||||
.num_routed_wires()
|
||||
.max(interpolation_gate.num_routed_wires());
|
||||
.max(interpolation_routed_wires);
|
||||
|
||||
assert!(
|
||||
self.config.num_wires >= min_wires,
|
||||
"To efficiently perform FRI checks with an arity of {}, at least {} wires are needed. Consider reducing arity.",
|
||||
max_fri_arity,
|
||||
"To efficiently perform FRI checks with an arity of 2^{}, at least {} wires are needed. Consider reducing arity.",
|
||||
max_fri_arity_bits,
|
||||
min_wires
|
||||
);
|
||||
|
||||
assert!(
|
||||
self.config.num_routed_wires >= min_routed_wires,
|
||||
"To efficiently perform FRI checks with an arity of {}, at least {} routed wires are needed. Consider reducing arity.",
|
||||
max_fri_arity,
|
||||
"To efficiently perform FRI checks with an arity of 2^{}, at least {} routed wires are needed. Consider reducing arity.",
|
||||
max_fri_arity_bits,
|
||||
min_routed_wires
|
||||
);
|
||||
}
|
||||
@ -108,8 +137,8 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
|
||||
) {
|
||||
let config = &common_data.config;
|
||||
|
||||
if let Some(max_arity) = common_data.fri_params.max_arity() {
|
||||
self.check_recursion_config(max_arity);
|
||||
if let Some(max_arity_bits) = common_data.fri_params.max_arity_bits() {
|
||||
self.check_recursion_config(max_arity_bits, common_data);
|
||||
}
|
||||
|
||||
debug_assert_eq!(
|
||||
@ -233,7 +262,7 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
|
||||
common_data: &CommonCircuitData<F, C, D>,
|
||||
) -> ExtensionTarget<D> {
|
||||
assert!(D > 1, "Not implemented for D=1.");
|
||||
let config = self.config.clone();
|
||||
let config = &common_data.config;
|
||||
let degree_log = common_data.degree_bits;
|
||||
debug_assert_eq!(
|
||||
degree_log,
|
||||
@ -306,9 +335,13 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
|
||||
common_data: &CommonCircuitData<F, C, D>,
|
||||
) {
|
||||
let n_log = log2_strict(n);
|
||||
// TODO: Do we need to range check `x_index` to a target smaller than `p`?
|
||||
|
||||
// Note that this `low_bits` decomposition permits non-canonical binary encodings. Here we
|
||||
// verify that this has a negligible impact on soundness error.
|
||||
Self::assert_noncanonical_indices_ok(&common_data.config);
|
||||
let x_index = challenger.get_challenge(self);
|
||||
let mut x_index_bits = self.low_bits(x_index, n_log, 64);
|
||||
let mut x_index_bits = self.low_bits(x_index, n_log, F::BITS);
|
||||
|
||||
let cap_index =
|
||||
self.le_sum(x_index_bits[x_index_bits.len() - common_data.config.cap_height..].iter());
|
||||
with_context!(
|
||||
@ -376,6 +409,7 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
|
||||
arity_bits,
|
||||
evals,
|
||||
betas[i],
|
||||
common_data
|
||||
)
|
||||
);
|
||||
|
||||
@ -409,6 +443,26 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
|
||||
);
|
||||
self.connect_extension(eval, old_eval);
|
||||
}
|
||||
|
||||
/// We decompose FRI query indices into bits without verifying that the decomposition given by
|
||||
/// the prover is the canonical one. In particular, if `x_index < 2^field_bits - p`, then the
|
||||
/// prover could supply the binary encoding of either `x_index` or `x_index + p`, since the are
|
||||
/// congruent mod `p`. However, this only occurs with probability
|
||||
/// p_ambiguous = (2^field_bits - p) / p
|
||||
/// which is small for the field that we use in practice.
|
||||
///
|
||||
/// In particular, the soundness error of one FRI query is roughly the codeword rate, which
|
||||
/// is much larger than this ambiguous-element probability given any reasonable parameters.
|
||||
/// Thus ambiguous elements contribute a negligible amount to soundness error.
|
||||
///
|
||||
/// Here we compare the probabilities as a sanity check, to verify the claim above.
|
||||
fn assert_noncanonical_indices_ok(config: &CircuitConfig) {
|
||||
let num_ambiguous_elems = u64::MAX - F::ORDER + 1;
|
||||
let query_error = config.rate();
|
||||
let p_ambiguous = (num_ambiguous_elems as f64) / (F::ORDER as f64);
|
||||
assert!(p_ambiguous < query_error * 1e-5,
|
||||
"A non-negligible portion of field elements are in the range that permits non-canonical encodings. Need to do more analysis or enforce canonical encodings.");
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone)]
|
||||
|
||||
@ -2,6 +2,8 @@ use std::borrow::Borrow;
|
||||
|
||||
use crate::field::extension_field::Extendable;
|
||||
use crate::gates::arithmetic::ArithmeticExtensionGate;
|
||||
use crate::field::field_types::{PrimeField, RichField};
|
||||
use crate::gates::arithmetic_base::ArithmeticGate;
|
||||
use crate::gates::exponentiation::ExponentiationGate;
|
||||
use crate::iop::target::{BoolTarget, Target};
|
||||
use crate::plonk::circuit_builder::CircuitBuilder;
|
||||
@ -32,18 +34,117 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
|
||||
multiplicand_1: Target,
|
||||
addend: Target,
|
||||
) -> Target {
|
||||
let multiplicand_0_ext = self.convert_to_ext(multiplicand_0);
|
||||
let multiplicand_1_ext = self.convert_to_ext(multiplicand_1);
|
||||
let addend_ext = self.convert_to_ext(addend);
|
||||
// If we're not configured to use the base arithmetic gate, just call arithmetic_extension.
|
||||
if !self.config.use_base_arithmetic_gate {
|
||||
let multiplicand_0_ext = self.convert_to_ext(multiplicand_0);
|
||||
let multiplicand_1_ext = self.convert_to_ext(multiplicand_1);
|
||||
let addend_ext = self.convert_to_ext(addend);
|
||||
|
||||
self.arithmetic_extension(
|
||||
return self
|
||||
.arithmetic_extension(
|
||||
const_0,
|
||||
const_1,
|
||||
multiplicand_0_ext,
|
||||
multiplicand_1_ext,
|
||||
addend_ext,
|
||||
)
|
||||
.0[0];
|
||||
}
|
||||
|
||||
// See if we can determine the result without adding an `ArithmeticGate`.
|
||||
if let Some(result) =
|
||||
self.arithmetic_special_cases(const_0, const_1, multiplicand_0, multiplicand_1, addend)
|
||||
{
|
||||
return result;
|
||||
}
|
||||
|
||||
// See if we've already computed the same operation.
|
||||
let operation = BaseArithmeticOperation {
|
||||
const_0,
|
||||
const_1,
|
||||
multiplicand_0_ext,
|
||||
multiplicand_1_ext,
|
||||
addend_ext,
|
||||
)
|
||||
.0[0]
|
||||
multiplicand_0,
|
||||
multiplicand_1,
|
||||
addend,
|
||||
};
|
||||
if let Some(&result) = self.base_arithmetic_results.get(&operation) {
|
||||
return result;
|
||||
}
|
||||
|
||||
// Otherwise, we must actually perform the operation using an ArithmeticExtensionGate slot.
|
||||
let result = self.add_base_arithmetic_operation(operation);
|
||||
self.base_arithmetic_results.insert(operation, result);
|
||||
result
|
||||
}
|
||||
|
||||
fn add_base_arithmetic_operation(&mut self, operation: BaseArithmeticOperation<F>) -> Target {
|
||||
let (gate, i) = self.find_base_arithmetic_gate(operation.const_0, operation.const_1);
|
||||
let wires_multiplicand_0 = Target::wire(gate, ArithmeticGate::wire_ith_multiplicand_0(i));
|
||||
let wires_multiplicand_1 = Target::wire(gate, ArithmeticGate::wire_ith_multiplicand_1(i));
|
||||
let wires_addend = Target::wire(gate, ArithmeticGate::wire_ith_addend(i));
|
||||
|
||||
self.connect(operation.multiplicand_0, wires_multiplicand_0);
|
||||
self.connect(operation.multiplicand_1, wires_multiplicand_1);
|
||||
self.connect(operation.addend, wires_addend);
|
||||
|
||||
Target::wire(gate, ArithmeticGate::wire_ith_output(i))
|
||||
}
|
||||
|
||||
/// Checks for special cases where the value of
|
||||
/// `const_0 * multiplicand_0 * multiplicand_1 + const_1 * addend`
|
||||
/// can be determined without adding an `ArithmeticGate`.
|
||||
fn arithmetic_special_cases(
|
||||
&mut self,
|
||||
const_0: F,
|
||||
const_1: F,
|
||||
multiplicand_0: Target,
|
||||
multiplicand_1: Target,
|
||||
addend: Target,
|
||||
) -> Option<Target> {
|
||||
let zero = self.zero();
|
||||
|
||||
let mul_0_const = self.target_as_constant(multiplicand_0);
|
||||
let mul_1_const = self.target_as_constant(multiplicand_1);
|
||||
let addend_const = self.target_as_constant(addend);
|
||||
|
||||
let first_term_zero =
|
||||
const_0 == F::ZERO || multiplicand_0 == zero || multiplicand_1 == zero;
|
||||
let second_term_zero = const_1 == F::ZERO || addend == zero;
|
||||
|
||||
// If both terms are constant, return their (constant) sum.
|
||||
let first_term_const = if first_term_zero {
|
||||
Some(F::ZERO)
|
||||
} else if let (Some(x), Some(y)) = (mul_0_const, mul_1_const) {
|
||||
Some(x * y * const_0)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let second_term_const = if second_term_zero {
|
||||
Some(F::ZERO)
|
||||
} else {
|
||||
addend_const.map(|x| x * const_1)
|
||||
};
|
||||
if let (Some(x), Some(y)) = (first_term_const, second_term_const) {
|
||||
return Some(self.constant(x + y));
|
||||
}
|
||||
|
||||
if first_term_zero && const_1.is_one() {
|
||||
return Some(addend);
|
||||
}
|
||||
|
||||
if second_term_zero {
|
||||
if let Some(x) = mul_0_const {
|
||||
if (x * const_0).is_one() {
|
||||
return Some(multiplicand_1);
|
||||
}
|
||||
}
|
||||
if let Some(x) = mul_1_const {
|
||||
if (x * const_0).is_one() {
|
||||
return Some(multiplicand_0);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
/// Computes `x * y + z`.
|
||||
@ -53,20 +154,20 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
|
||||
|
||||
/// Computes `x + C`.
|
||||
pub fn add_const(&mut self, x: Target, c: F) -> Target {
|
||||
let one = self.one();
|
||||
self.arithmetic(F::ONE, c, one, x, one)
|
||||
let c = self.constant(c);
|
||||
self.add(x, c)
|
||||
}
|
||||
|
||||
/// Computes `C * x`.
|
||||
pub fn mul_const(&mut self, c: F, x: Target) -> Target {
|
||||
let zero = self.zero();
|
||||
self.mul_const_add(c, x, zero)
|
||||
let c = self.constant(c);
|
||||
self.mul(c, x)
|
||||
}
|
||||
|
||||
/// Computes `C * x + y`.
|
||||
pub fn mul_const_add(&mut self, c: F, x: Target, y: Target) -> Target {
|
||||
let one = self.one();
|
||||
self.arithmetic(c, F::ONE, x, one, y)
|
||||
let c = self.constant(c);
|
||||
self.mul_add(c, x, y)
|
||||
}
|
||||
|
||||
/// Computes `x * y - z`.
|
||||
@ -82,13 +183,8 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
|
||||
}
|
||||
|
||||
/// Add `n` `Target`s.
|
||||
// TODO: Can be made `D` times more efficient by using all wires of an `ArithmeticExtensionGate`.
|
||||
pub fn add_many(&mut self, terms: &[Target]) -> Target {
|
||||
let terms_ext = terms
|
||||
.iter()
|
||||
.map(|&t| self.convert_to_ext(t))
|
||||
.collect::<Vec<_>>();
|
||||
self.add_many_extension(&terms_ext).to_target_array()[0]
|
||||
terms.iter().fold(self.zero(), |acc, &t| self.add(acc, t))
|
||||
}
|
||||
|
||||
/// Computes `x - y`.
|
||||
@ -106,16 +202,16 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
|
||||
|
||||
/// Multiply `n` `Target`s.
|
||||
pub fn mul_many(&mut self, terms: &[Target]) -> Target {
|
||||
let terms_ext = terms
|
||||
terms
|
||||
.iter()
|
||||
.map(|&t| self.convert_to_ext(t))
|
||||
.collect::<Vec<_>>();
|
||||
self.mul_many_extension(&terms_ext).to_target_array()[0]
|
||||
.copied()
|
||||
.reduce(|acc, t| self.mul(acc, t))
|
||||
.unwrap_or_else(|| self.one())
|
||||
}
|
||||
|
||||
/// Exponentiate `base` to the power of `2^power_log`.
|
||||
pub fn exp_power_of_2(&mut self, base: Target, power_log: usize) -> Target {
|
||||
if power_log > ArithmeticExtensionGate::<D>::new_from_config(&self.config).num_ops {
|
||||
if power_log > self.num_base_arithmetic_ops_per_gate() {
|
||||
// Cheaper to just use `ExponentiateGate`.
|
||||
return self.exp_u64(base, 1 << power_log);
|
||||
}
|
||||
@ -169,8 +265,7 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
|
||||
let base_t = self.constant(base);
|
||||
let exponent_bits: Vec<_> = exponent_bits.into_iter().map(|b| *b.borrow()).collect();
|
||||
|
||||
if exponent_bits.len() > ArithmeticExtensionGate::<D>::new_from_config(&self.config).num_ops
|
||||
{
|
||||
if exponent_bits.len() > self.num_base_arithmetic_ops_per_gate() {
|
||||
// Cheaper to just use `ExponentiateGate`.
|
||||
return self.exp_from_bits(base_t, exponent_bits);
|
||||
}
|
||||
@ -220,3 +315,13 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
|
||||
self.inverse_extension(x_ext).0[0]
|
||||
}
|
||||
}
|
||||
|
||||
/// Represents a base arithmetic operation in the circuit. Used to memoize results.
|
||||
#[derive(Copy, Clone, Eq, PartialEq, Hash)]
|
||||
pub(crate) struct BaseArithmeticOperation<F: PrimeField> {
|
||||
const_0: F,
|
||||
const_1: F,
|
||||
multiplicand_0: Target,
|
||||
multiplicand_1: Target,
|
||||
addend: Target,
|
||||
}
|
||||
|
||||
@ -1,8 +1,9 @@
|
||||
use std::convert::TryInto;
|
||||
|
||||
use crate::field::extension_field::target::{ExtensionAlgebraTarget, ExtensionTarget};
|
||||
use crate::field::extension_field::FieldExtension;
|
||||
use crate::field::extension_field::{Extendable, OEF};
|
||||
use crate::field::field_types::{Field, PrimeField, RichField};
|
||||
use crate::gates::arithmetic_extension::ArithmeticExtensionGate;
|
||||
use crate::gates::multiplication_extension::MulExtensionGate;
|
||||
use crate::field::field_types::{Field, PrimeField};
|
||||
use crate::gates::arithmetic::ArithmeticExtensionGate;
|
||||
use crate::iop::generator::{GeneratedValues, SimpleGenerator};
|
||||
@ -12,33 +13,6 @@ use crate::plonk::circuit_builder::CircuitBuilder;
|
||||
use crate::util::bits_u64;
|
||||
|
||||
impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
|
||||
/// Finds the last available arithmetic gate with the given constants or add one if there aren't any.
|
||||
/// Returns `(g,i)` such that there is an arithmetic gate with the given constants at index
|
||||
/// `g` and the gate's `i`-th operation is available.
|
||||
fn find_arithmetic_gate(&mut self, const_0: F, const_1: F) -> (usize, usize) {
|
||||
let (gate, i) = self
|
||||
.free_arithmetic
|
||||
.get(&(const_0, const_1))
|
||||
.copied()
|
||||
.unwrap_or_else(|| {
|
||||
let gate = self.add_gate(
|
||||
ArithmeticExtensionGate::new_from_config(&self.config),
|
||||
vec![const_0, const_1],
|
||||
);
|
||||
(gate, 0)
|
||||
});
|
||||
|
||||
// Update `free_arithmetic` with new values.
|
||||
if i < ArithmeticExtensionGate::<D>::num_ops(&self.config) - 1 {
|
||||
self.free_arithmetic
|
||||
.insert((const_0, const_1), (gate, i + 1));
|
||||
} else {
|
||||
self.free_arithmetic.remove(&(const_0, const_1));
|
||||
}
|
||||
|
||||
(gate, i)
|
||||
}
|
||||
|
||||
pub fn arithmetic_extension(
|
||||
&mut self,
|
||||
const_0: F,
|
||||
@ -59,7 +33,7 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
|
||||
}
|
||||
|
||||
// See if we've already computed the same operation.
|
||||
let operation = ArithmeticOperation {
|
||||
let operation = ExtensionArithmeticOperation {
|
||||
const_0,
|
||||
const_1,
|
||||
multiplicand_0,
|
||||
@ -70,15 +44,21 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
|
||||
return result;
|
||||
}
|
||||
|
||||
let result = if self.target_as_constant_ext(addend) == Some(F::Extension::ZERO) {
|
||||
// If the addend is zero, we use a multiplication gate.
|
||||
self.compute_mul_extension_operation(operation)
|
||||
} else {
|
||||
// Otherwise, we use an arithmetic gate.
|
||||
self.compute_arithmetic_extension_operation(operation)
|
||||
};
|
||||
// Otherwise, we must actually perform the operation using an ArithmeticExtensionGate slot.
|
||||
let result = self.add_arithmetic_extension_operation(operation);
|
||||
self.arithmetic_results.insert(operation, result);
|
||||
result
|
||||
}
|
||||
|
||||
fn add_arithmetic_extension_operation(
|
||||
fn compute_arithmetic_extension_operation(
|
||||
&mut self,
|
||||
operation: ArithmeticOperation<F, D>,
|
||||
operation: ExtensionArithmeticOperation<F, D>,
|
||||
) -> ExtensionTarget<D> {
|
||||
let (gate, i) = self.find_arithmetic_gate(operation.const_0, operation.const_1);
|
||||
let wires_multiplicand_0 = ExtensionTarget::from_range(
|
||||
@ -99,6 +79,22 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
|
||||
ExtensionTarget::from_range(gate, ArithmeticExtensionGate::<D>::wires_ith_output(i))
|
||||
}
|
||||
|
||||
fn compute_mul_extension_operation(
|
||||
&mut self,
|
||||
operation: ExtensionArithmeticOperation<F, D>,
|
||||
) -> ExtensionTarget<D> {
|
||||
let (gate, i) = self.find_mul_gate(operation.const_0);
|
||||
let wires_multiplicand_0 =
|
||||
ExtensionTarget::from_range(gate, MulExtensionGate::<D>::wires_ith_multiplicand_0(i));
|
||||
let wires_multiplicand_1 =
|
||||
ExtensionTarget::from_range(gate, MulExtensionGate::<D>::wires_ith_multiplicand_1(i));
|
||||
|
||||
self.connect_extension(operation.multiplicand_0, wires_multiplicand_0);
|
||||
self.connect_extension(operation.multiplicand_1, wires_multiplicand_1);
|
||||
|
||||
ExtensionTarget::from_range(gate, MulExtensionGate::<D>::wires_ith_output(i))
|
||||
}
|
||||
|
||||
/// Checks for special cases where the value of
|
||||
/// `const_0 * multiplicand_0 * multiplicand_1 + const_1 * addend`
|
||||
/// can be determined without adding an `ArithmeticGate`.
|
||||
@ -302,11 +298,11 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
|
||||
|
||||
/// Multiply `n` `ExtensionTarget`s.
|
||||
pub fn mul_many_extension(&mut self, terms: &[ExtensionTarget<D>]) -> ExtensionTarget<D> {
|
||||
let mut product = self.one_extension();
|
||||
for &term in terms {
|
||||
product = self.mul_extension(product, term);
|
||||
}
|
||||
product
|
||||
terms
|
||||
.iter()
|
||||
.copied()
|
||||
.reduce(|acc, t| self.mul_extension(acc, t))
|
||||
.unwrap_or_else(|| self.one_extension())
|
||||
}
|
||||
|
||||
/// Like `mul_add`, but for `ExtensionTarget`s.
|
||||
@ -321,14 +317,14 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
|
||||
|
||||
/// Like `add_const`, but for `ExtensionTarget`s.
|
||||
pub fn add_const_extension(&mut self, x: ExtensionTarget<D>, c: F) -> ExtensionTarget<D> {
|
||||
let one = self.one_extension();
|
||||
self.arithmetic_extension(F::ONE, c, one, x, one)
|
||||
let c = self.constant_extension(c.into());
|
||||
self.add_extension(x, c)
|
||||
}
|
||||
|
||||
/// Like `mul_const`, but for `ExtensionTarget`s.
|
||||
pub fn mul_const_extension(&mut self, c: F, x: ExtensionTarget<D>) -> ExtensionTarget<D> {
|
||||
let zero = self.zero_extension();
|
||||
self.mul_const_add_extension(c, x, zero)
|
||||
let c = self.constant_extension(c.into());
|
||||
self.mul_extension(c, x)
|
||||
}
|
||||
|
||||
/// Like `mul_const_add`, but for `ExtensionTarget`s.
|
||||
@ -338,8 +334,8 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
|
||||
x: ExtensionTarget<D>,
|
||||
y: ExtensionTarget<D>,
|
||||
) -> ExtensionTarget<D> {
|
||||
let one = self.one_extension();
|
||||
self.arithmetic_extension(c, F::ONE, x, one, y)
|
||||
let c = self.constant_extension(c.into());
|
||||
self.mul_add_extension(c, x, y)
|
||||
}
|
||||
|
||||
/// Like `mul_add`, but for `ExtensionTarget`s.
|
||||
@ -544,9 +540,9 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
|
||||
}
|
||||
}
|
||||
|
||||
/// Represents an arithmetic operation in the circuit. Used to memoize results.
|
||||
/// Represents an extension arithmetic operation in the circuit. Used to memoize results.
|
||||
#[derive(Copy, Clone, Eq, PartialEq, Hash)]
|
||||
pub(crate) struct ArithmeticOperation<F: PrimeField + Extendable<D>, const D: usize> {
|
||||
pub(crate) struct ExtensionArithmeticOperation<F: PrimeField + Extendable<D>, const D: usize> {
|
||||
const_0: F,
|
||||
const_1: F,
|
||||
multiplicand_0: ExtensionTarget<D>,
|
||||
@ -556,11 +552,11 @@ pub(crate) struct ArithmeticOperation<F: PrimeField + Extendable<D>, const D: us
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::convert::TryInto;
|
||||
|
||||
use anyhow::Result;
|
||||
|
||||
use crate::field::extension_field::algebra::ExtensionAlgebra;
|
||||
use crate::field::extension_field::quartic::QuarticExtension;
|
||||
use crate::field::extension_field::target::ExtensionAlgebraTarget;
|
||||
use crate::field::field_types::Field;
|
||||
use crate::iop::witness::{PartialWitness, Witness};
|
||||
use crate::plonk::circuit_builder::CircuitBuilder;
|
||||
@ -623,9 +619,7 @@ mod tests {
|
||||
let yt = builder.constant_extension(y);
|
||||
let zt = builder.constant_extension(z);
|
||||
let comp_zt = builder.div_extension(xt, yt);
|
||||
let comp_zt_unsafe = builder.div_extension(xt, yt);
|
||||
builder.connect_extension(zt, comp_zt);
|
||||
builder.connect_extension(zt, comp_zt_unsafe);
|
||||
|
||||
let data = builder.build::<C>();
|
||||
let proof = data.prove(pw)?;
|
||||
@ -642,23 +636,29 @@ mod tests {
|
||||
|
||||
let config = CircuitConfig::standard_recursion_config();
|
||||
|
||||
let pw = PartialWitness::new();
|
||||
let mut pw = PartialWitness::new();
|
||||
let mut builder = CircuitBuilder::<F, D>::new(config);
|
||||
|
||||
let x = FF::rand_vec(D);
|
||||
let y = FF::rand_vec(D);
|
||||
let xa = ExtensionAlgebra(x.try_into().unwrap());
|
||||
let ya = ExtensionAlgebra(y.try_into().unwrap());
|
||||
let za = xa * ya;
|
||||
|
||||
let xt = builder.constant_ext_algebra(xa);
|
||||
let yt = builder.constant_ext_algebra(ya);
|
||||
let zt = builder.constant_ext_algebra(za);
|
||||
let xt =
|
||||
ExtensionAlgebraTarget(builder.add_virtual_extension_targets(D).try_into().unwrap());
|
||||
let yt =
|
||||
ExtensionAlgebraTarget(builder.add_virtual_extension_targets(D).try_into().unwrap());
|
||||
let zt =
|
||||
ExtensionAlgebraTarget(builder.add_virtual_extension_targets(D).try_into().unwrap());
|
||||
let comp_zt = builder.mul_ext_algebra(xt, yt);
|
||||
for i in 0..D {
|
||||
builder.connect_extension(zt.0[i], comp_zt.0[i]);
|
||||
}
|
||||
|
||||
let x = ExtensionAlgebra::<FF, D>(FF::rand_arr());
|
||||
let y = ExtensionAlgebra::<FF, D>(FF::rand_arr());
|
||||
let z = x * y;
|
||||
for i in 0..D {
|
||||
pw.set_extension_target(xt.0[i], x.0[i]);
|
||||
pw.set_extension_target(yt.0[i], y.0[i]);
|
||||
pw.set_extension_target(zt.0[i], z.0[i]);
|
||||
}
|
||||
|
||||
let data = builder.build::<C>();
|
||||
let proof = data.prove(pw)?;
|
||||
|
||||
|
||||
154
src/gadgets/arithmetic_u32.rs
Normal file
154
src/gadgets/arithmetic_u32.rs
Normal file
@ -0,0 +1,154 @@
|
||||
use crate::field::extension_field::Extendable;
|
||||
use crate::field::field_types::RichField;
|
||||
use crate::gates::arithmetic_u32::U32ArithmeticGate;
|
||||
use crate::gates::subtraction_u32::U32SubtractionGate;
|
||||
use crate::iop::target::Target;
|
||||
use crate::plonk::circuit_builder::CircuitBuilder;
|
||||
|
||||
#[derive(Clone, Copy, Debug)]
|
||||
pub struct U32Target(pub Target);
|
||||
|
||||
impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
|
||||
pub fn add_virtual_u32_target(&mut self) -> U32Target {
|
||||
U32Target(self.add_virtual_target())
|
||||
}
|
||||
|
||||
pub fn add_virtual_u32_targets(&mut self, n: usize) -> Vec<U32Target> {
|
||||
self.add_virtual_targets(n)
|
||||
.into_iter()
|
||||
.map(U32Target)
|
||||
.collect()
|
||||
}
|
||||
|
||||
pub fn zero_u32(&mut self) -> U32Target {
|
||||
U32Target(self.zero())
|
||||
}
|
||||
|
||||
pub fn one_u32(&mut self) -> U32Target {
|
||||
U32Target(self.one())
|
||||
}
|
||||
|
||||
pub fn connect_u32(&mut self, x: U32Target, y: U32Target) {
|
||||
self.connect(x.0, y.0)
|
||||
}
|
||||
|
||||
pub fn assert_zero_u32(&mut self, x: U32Target) {
|
||||
self.assert_zero(x.0)
|
||||
}
|
||||
|
||||
/// Checks for special cases where the value of
|
||||
/// `x * y + z`
|
||||
/// can be determined without adding a `U32ArithmeticGate`.
|
||||
pub fn arithmetic_u32_special_cases(
|
||||
&mut self,
|
||||
x: U32Target,
|
||||
y: U32Target,
|
||||
z: U32Target,
|
||||
) -> Option<(U32Target, U32Target)> {
|
||||
let x_const = self.target_as_constant(x.0);
|
||||
let y_const = self.target_as_constant(y.0);
|
||||
let z_const = self.target_as_constant(z.0);
|
||||
|
||||
// If both terms are constant, return their (constant) sum.
|
||||
let first_term_const = if let (Some(xx), Some(yy)) = (x_const, y_const) {
|
||||
Some(xx * yy)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
if let (Some(a), Some(b)) = (first_term_const, z_const) {
|
||||
let sum = (a + b).to_canonical_u64();
|
||||
let (low, high) = (sum as u32, (sum >> 32) as u32);
|
||||
return Some((self.constant_u32(low), self.constant_u32(high)));
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
// Returns x * y + z.
|
||||
pub fn mul_add_u32(
|
||||
&mut self,
|
||||
x: U32Target,
|
||||
y: U32Target,
|
||||
z: U32Target,
|
||||
) -> (U32Target, U32Target) {
|
||||
if let Some(result) = self.arithmetic_u32_special_cases(x, y, z) {
|
||||
return result;
|
||||
}
|
||||
|
||||
let gate = U32ArithmeticGate::<F, D>::new_from_config(&self.config);
|
||||
let (gate_index, copy) = self.find_u32_arithmetic_gate();
|
||||
|
||||
self.connect(
|
||||
Target::wire(gate_index, gate.wire_ith_multiplicand_0(copy)),
|
||||
x.0,
|
||||
);
|
||||
self.connect(
|
||||
Target::wire(gate_index, gate.wire_ith_multiplicand_1(copy)),
|
||||
y.0,
|
||||
);
|
||||
self.connect(Target::wire(gate_index, gate.wire_ith_addend(copy)), z.0);
|
||||
|
||||
let output_low = U32Target(Target::wire(
|
||||
gate_index,
|
||||
gate.wire_ith_output_low_half(copy),
|
||||
));
|
||||
let output_high = U32Target(Target::wire(
|
||||
gate_index,
|
||||
gate.wire_ith_output_high_half(copy),
|
||||
));
|
||||
|
||||
(output_low, output_high)
|
||||
}
|
||||
|
||||
pub fn add_u32(&mut self, a: U32Target, b: U32Target) -> (U32Target, U32Target) {
|
||||
let one = self.one_u32();
|
||||
self.mul_add_u32(a, one, b)
|
||||
}
|
||||
|
||||
pub fn add_many_u32(&mut self, to_add: &[U32Target]) -> (U32Target, U32Target) {
|
||||
match to_add.len() {
|
||||
0 => (self.zero_u32(), self.zero_u32()),
|
||||
1 => (to_add[0], self.zero_u32()),
|
||||
2 => self.add_u32(to_add[0], to_add[1]),
|
||||
_ => {
|
||||
let (mut low, mut carry) = self.add_u32(to_add[0], to_add[1]);
|
||||
for i in 2..to_add.len() {
|
||||
let (new_low, new_carry) = self.add_u32(to_add[i], low);
|
||||
let (combined_carry, _zero) = self.add_u32(carry, new_carry);
|
||||
low = new_low;
|
||||
carry = combined_carry;
|
||||
}
|
||||
(low, carry)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn mul_u32(&mut self, a: U32Target, b: U32Target) -> (U32Target, U32Target) {
|
||||
let zero = self.zero_u32();
|
||||
self.mul_add_u32(a, b, zero)
|
||||
}
|
||||
|
||||
// Returns x - y - borrow, as a pair (result, borrow), where borrow is 0 or 1 depending on whether borrowing from the next digit is required (iff y + borrow > x).
|
||||
pub fn sub_u32(
|
||||
&mut self,
|
||||
x: U32Target,
|
||||
y: U32Target,
|
||||
borrow: U32Target,
|
||||
) -> (U32Target, U32Target) {
|
||||
let gate = U32SubtractionGate::<F, D>::new_from_config(&self.config);
|
||||
let (gate_index, copy) = self.find_u32_subtraction_gate();
|
||||
|
||||
self.connect(Target::wire(gate_index, gate.wire_ith_input_x(copy)), x.0);
|
||||
self.connect(Target::wire(gate_index, gate.wire_ith_input_y(copy)), y.0);
|
||||
self.connect(
|
||||
Target::wire(gate_index, gate.wire_ith_input_borrow(copy)),
|
||||
borrow.0,
|
||||
);
|
||||
|
||||
let output_result = U32Target(Target::wire(gate_index, gate.wire_ith_output_result(copy)));
|
||||
let output_borrow = U32Target(Target::wire(gate_index, gate.wire_ith_output_borrow(copy)));
|
||||
|
||||
(output_result, output_borrow)
|
||||
}
|
||||
}
|
||||
395
src/gadgets/biguint.rs
Normal file
395
src/gadgets/biguint.rs
Normal file
@ -0,0 +1,395 @@
|
||||
use std::marker::PhantomData;
|
||||
|
||||
use num::{BigUint, Integer};
|
||||
|
||||
use crate::field::extension_field::Extendable;
|
||||
use crate::field::field_types::RichField;
|
||||
use crate::gadgets::arithmetic_u32::U32Target;
|
||||
use crate::iop::generator::{GeneratedValues, SimpleGenerator};
|
||||
use crate::iop::target::{BoolTarget, Target};
|
||||
use crate::iop::witness::{PartitionWitness, Witness};
|
||||
use crate::plonk::circuit_builder::CircuitBuilder;
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct BigUintTarget {
|
||||
pub limbs: Vec<U32Target>,
|
||||
}
|
||||
|
||||
impl BigUintTarget {
|
||||
pub fn num_limbs(&self) -> usize {
|
||||
self.limbs.len()
|
||||
}
|
||||
|
||||
pub fn get_limb(&self, i: usize) -> U32Target {
|
||||
self.limbs[i]
|
||||
}
|
||||
}
|
||||
|
||||
impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
|
||||
pub fn constant_biguint(&mut self, value: &BigUint) -> BigUintTarget {
|
||||
let limb_values = value.to_u32_digits();
|
||||
let limbs = limb_values.iter().map(|&l| self.constant_u32(l)).collect();
|
||||
|
||||
BigUintTarget { limbs }
|
||||
}
|
||||
|
||||
pub fn connect_biguint(&mut self, lhs: &BigUintTarget, rhs: &BigUintTarget) {
|
||||
let min_limbs = lhs.num_limbs().min(rhs.num_limbs());
|
||||
for i in 0..min_limbs {
|
||||
self.connect_u32(lhs.get_limb(i), rhs.get_limb(i));
|
||||
}
|
||||
|
||||
for i in min_limbs..lhs.num_limbs() {
|
||||
self.assert_zero_u32(lhs.get_limb(i));
|
||||
}
|
||||
for i in min_limbs..rhs.num_limbs() {
|
||||
self.assert_zero_u32(rhs.get_limb(i));
|
||||
}
|
||||
}
|
||||
|
||||
pub fn pad_biguints(
|
||||
&mut self,
|
||||
a: &BigUintTarget,
|
||||
b: &BigUintTarget,
|
||||
) -> (BigUintTarget, BigUintTarget) {
|
||||
if a.num_limbs() > b.num_limbs() {
|
||||
let mut padded_b = b.clone();
|
||||
for _ in b.num_limbs()..a.num_limbs() {
|
||||
padded_b.limbs.push(self.zero_u32());
|
||||
}
|
||||
|
||||
(a.clone(), padded_b)
|
||||
} else {
|
||||
let mut padded_a = a.clone();
|
||||
for _ in a.num_limbs()..b.num_limbs() {
|
||||
padded_a.limbs.push(self.zero_u32());
|
||||
}
|
||||
|
||||
(padded_a, b.clone())
|
||||
}
|
||||
}
|
||||
|
||||
pub fn cmp_biguint(&mut self, a: &BigUintTarget, b: &BigUintTarget) -> BoolTarget {
|
||||
let (a, b) = self.pad_biguints(a, b);
|
||||
|
||||
self.list_le_u32(a.limbs, b.limbs)
|
||||
}
|
||||
|
||||
pub fn add_virtual_biguint_target(&mut self, num_limbs: usize) -> BigUintTarget {
|
||||
let limbs = (0..num_limbs)
|
||||
.map(|_| self.add_virtual_u32_target())
|
||||
.collect();
|
||||
|
||||
BigUintTarget { limbs }
|
||||
}
|
||||
|
||||
// Add two `BigUintTarget`s.
|
||||
pub fn add_biguint(&mut self, a: &BigUintTarget, b: &BigUintTarget) -> BigUintTarget {
|
||||
let num_limbs = a.num_limbs().max(b.num_limbs());
|
||||
|
||||
let mut combined_limbs = vec![];
|
||||
let mut carry = self.zero_u32();
|
||||
for i in 0..num_limbs {
|
||||
let a_limb = (i < a.num_limbs())
|
||||
.then(|| a.limbs[i])
|
||||
.unwrap_or_else(|| self.zero_u32());
|
||||
let b_limb = (i < b.num_limbs())
|
||||
.then(|| b.limbs[i])
|
||||
.unwrap_or_else(|| self.zero_u32());
|
||||
|
||||
let (new_limb, new_carry) = self.add_many_u32(&[carry, a_limb, b_limb]);
|
||||
carry = new_carry;
|
||||
combined_limbs.push(new_limb);
|
||||
}
|
||||
combined_limbs.push(carry);
|
||||
|
||||
BigUintTarget {
|
||||
limbs: combined_limbs,
|
||||
}
|
||||
}
|
||||
|
||||
// Subtract two `BigUintTarget`s. We assume that the first is larger than the second.
|
||||
pub fn sub_biguint(&mut self, a: &BigUintTarget, b: &BigUintTarget) -> BigUintTarget {
|
||||
let (a, b) = self.pad_biguints(a, b);
|
||||
let num_limbs = a.limbs.len();
|
||||
|
||||
let mut result_limbs = vec![];
|
||||
|
||||
let mut borrow = self.zero_u32();
|
||||
for i in 0..num_limbs {
|
||||
let (result, new_borrow) = self.sub_u32(a.limbs[i], b.limbs[i], borrow);
|
||||
result_limbs.push(result);
|
||||
borrow = new_borrow;
|
||||
}
|
||||
// Borrow should be zero here.
|
||||
|
||||
BigUintTarget {
|
||||
limbs: result_limbs,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn mul_biguint(&mut self, a: &BigUintTarget, b: &BigUintTarget) -> BigUintTarget {
|
||||
let total_limbs = a.limbs.len() + b.limbs.len();
|
||||
|
||||
let mut to_add = vec![vec![]; total_limbs];
|
||||
for i in 0..a.limbs.len() {
|
||||
for j in 0..b.limbs.len() {
|
||||
let (product, carry) = self.mul_u32(a.limbs[i], b.limbs[j]);
|
||||
to_add[i + j].push(product);
|
||||
to_add[i + j + 1].push(carry);
|
||||
}
|
||||
}
|
||||
|
||||
let mut combined_limbs = vec![];
|
||||
let mut carry = self.zero_u32();
|
||||
for summands in &mut to_add {
|
||||
summands.push(carry);
|
||||
let (new_result, new_carry) = self.add_many_u32(summands);
|
||||
combined_limbs.push(new_result);
|
||||
carry = new_carry;
|
||||
}
|
||||
combined_limbs.push(carry);
|
||||
|
||||
BigUintTarget {
|
||||
limbs: combined_limbs,
|
||||
}
|
||||
}
|
||||
|
||||
// Returns x * y + z. This is no more efficient than mul-then-add; it's purely for convenience (only need to call one CircuitBuilder function).
|
||||
pub fn mul_add_biguint(
|
||||
&mut self,
|
||||
x: &BigUintTarget,
|
||||
y: &BigUintTarget,
|
||||
z: &BigUintTarget,
|
||||
) -> BigUintTarget {
|
||||
let prod = self.mul_biguint(x, y);
|
||||
self.add_biguint(&prod, z)
|
||||
}
|
||||
|
||||
pub fn div_rem_biguint(
|
||||
&mut self,
|
||||
a: &BigUintTarget,
|
||||
b: &BigUintTarget,
|
||||
) -> (BigUintTarget, BigUintTarget) {
|
||||
let a_len = a.limbs.len();
|
||||
let b_len = b.limbs.len();
|
||||
let div_num_limbs = if b_len > a_len + 1 {
|
||||
0
|
||||
} else {
|
||||
a_len - b_len + 1
|
||||
};
|
||||
let div = self.add_virtual_biguint_target(div_num_limbs);
|
||||
let rem = self.add_virtual_biguint_target(b_len);
|
||||
|
||||
self.add_simple_generator(BigUintDivRemGenerator::<F, D> {
|
||||
a: a.clone(),
|
||||
b: b.clone(),
|
||||
div: div.clone(),
|
||||
rem: rem.clone(),
|
||||
_phantom: PhantomData,
|
||||
});
|
||||
|
||||
let div_b = self.mul_biguint(&div, b);
|
||||
let div_b_plus_rem = self.add_biguint(&div_b, &rem);
|
||||
self.connect_biguint(a, &div_b_plus_rem);
|
||||
|
||||
let cmp_rem_b = self.cmp_biguint(&rem, b);
|
||||
self.assert_one(cmp_rem_b.target);
|
||||
|
||||
(div, rem)
|
||||
}
|
||||
|
||||
pub fn div_biguint(&mut self, a: &BigUintTarget, b: &BigUintTarget) -> BigUintTarget {
|
||||
let (div, _rem) = self.div_rem_biguint(a, b);
|
||||
div
|
||||
}
|
||||
|
||||
pub fn rem_biguint(&mut self, a: &BigUintTarget, b: &BigUintTarget) -> BigUintTarget {
|
||||
let (_div, rem) = self.div_rem_biguint(a, b);
|
||||
rem
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct BigUintDivRemGenerator<F: RichField + Extendable<D>, const D: usize> {
|
||||
a: BigUintTarget,
|
||||
b: BigUintTarget,
|
||||
div: BigUintTarget,
|
||||
rem: BigUintTarget,
|
||||
_phantom: PhantomData<F>,
|
||||
}
|
||||
|
||||
impl<F: RichField + Extendable<D>, const D: usize> SimpleGenerator<F>
|
||||
for BigUintDivRemGenerator<F, D>
|
||||
{
|
||||
fn dependencies(&self) -> Vec<Target> {
|
||||
self.a
|
||||
.limbs
|
||||
.iter()
|
||||
.chain(&self.b.limbs)
|
||||
.map(|&l| l.0)
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn run_once(&self, witness: &PartitionWitness<F>, out_buffer: &mut GeneratedValues<F>) {
|
||||
let a = witness.get_biguint_target(self.a.clone());
|
||||
let b = witness.get_biguint_target(self.b.clone());
|
||||
let (div, rem) = a.div_rem(&b);
|
||||
|
||||
out_buffer.set_biguint_target(self.div.clone(), div);
|
||||
out_buffer.set_biguint_target(self.rem.clone(), rem);
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use anyhow::Result;
|
||||
use num::{BigUint, FromPrimitive, Integer};
|
||||
use rand::Rng;
|
||||
|
||||
use crate::iop::witness::Witness;
|
||||
use crate::{
|
||||
field::goldilocks_field::GoldilocksField,
|
||||
iop::witness::PartialWitness,
|
||||
plonk::{circuit_builder::CircuitBuilder, circuit_data::CircuitConfig, verifier::verify},
|
||||
};
|
||||
|
||||
#[test]
|
||||
fn test_biguint_add() -> Result<()> {
|
||||
let mut rng = rand::thread_rng();
|
||||
|
||||
let x_value = BigUint::from_u128(rng.gen()).unwrap();
|
||||
let y_value = BigUint::from_u128(rng.gen()).unwrap();
|
||||
let expected_z_value = &x_value + &y_value;
|
||||
|
||||
type F = GoldilocksField;
|
||||
let config = CircuitConfig::standard_recursion_config();
|
||||
let mut pw = PartialWitness::new();
|
||||
let mut builder = CircuitBuilder::<F, 4>::new(config);
|
||||
|
||||
let x = builder.add_virtual_biguint_target(x_value.to_u32_digits().len());
|
||||
let y = builder.add_virtual_biguint_target(y_value.to_u32_digits().len());
|
||||
let z = builder.add_biguint(&x, &y);
|
||||
let expected_z = builder.add_virtual_biguint_target(expected_z_value.to_u32_digits().len());
|
||||
builder.connect_biguint(&z, &expected_z);
|
||||
|
||||
pw.set_biguint_target(&x, &x_value);
|
||||
pw.set_biguint_target(&y, &y_value);
|
||||
pw.set_biguint_target(&expected_z, &expected_z_value);
|
||||
|
||||
let data = builder.build();
|
||||
let proof = data.prove(pw).unwrap();
|
||||
verify(proof, &data.verifier_only, &data.common)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_biguint_sub() -> Result<()> {
|
||||
let mut rng = rand::thread_rng();
|
||||
|
||||
let mut x_value = BigUint::from_u128(rng.gen()).unwrap();
|
||||
let mut y_value = BigUint::from_u128(rng.gen()).unwrap();
|
||||
if y_value > x_value {
|
||||
(x_value, y_value) = (y_value, x_value);
|
||||
}
|
||||
let expected_z_value = &x_value - &y_value;
|
||||
|
||||
type F = GoldilocksField;
|
||||
let config = CircuitConfig::standard_recursion_config();
|
||||
let pw = PartialWitness::new();
|
||||
let mut builder = CircuitBuilder::<F, 4>::new(config);
|
||||
|
||||
let x = builder.constant_biguint(&x_value);
|
||||
let y = builder.constant_biguint(&y_value);
|
||||
let z = builder.sub_biguint(&x, &y);
|
||||
let expected_z = builder.constant_biguint(&expected_z_value);
|
||||
|
||||
builder.connect_biguint(&z, &expected_z);
|
||||
|
||||
let data = builder.build();
|
||||
let proof = data.prove(pw).unwrap();
|
||||
verify(proof, &data.verifier_only, &data.common)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_biguint_mul() -> Result<()> {
|
||||
let mut rng = rand::thread_rng();
|
||||
|
||||
let x_value = BigUint::from_u128(rng.gen()).unwrap();
|
||||
let y_value = BigUint::from_u128(rng.gen()).unwrap();
|
||||
let expected_z_value = &x_value * &y_value;
|
||||
|
||||
type F = GoldilocksField;
|
||||
let config = CircuitConfig::standard_recursion_config();
|
||||
let mut pw = PartialWitness::new();
|
||||
let mut builder = CircuitBuilder::<F, 4>::new(config);
|
||||
|
||||
let x = builder.add_virtual_biguint_target(x_value.to_u32_digits().len());
|
||||
let y = builder.add_virtual_biguint_target(y_value.to_u32_digits().len());
|
||||
let z = builder.mul_biguint(&x, &y);
|
||||
let expected_z = builder.add_virtual_biguint_target(expected_z_value.to_u32_digits().len());
|
||||
builder.connect_biguint(&z, &expected_z);
|
||||
|
||||
pw.set_biguint_target(&x, &x_value);
|
||||
pw.set_biguint_target(&y, &y_value);
|
||||
pw.set_biguint_target(&expected_z, &expected_z_value);
|
||||
|
||||
let data = builder.build();
|
||||
let proof = data.prove(pw).unwrap();
|
||||
verify(proof, &data.verifier_only, &data.common)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_biguint_cmp() -> Result<()> {
|
||||
let mut rng = rand::thread_rng();
|
||||
|
||||
let x_value = BigUint::from_u128(rng.gen()).unwrap();
|
||||
let y_value = BigUint::from_u128(rng.gen()).unwrap();
|
||||
|
||||
type F = GoldilocksField;
|
||||
let config = CircuitConfig::standard_recursion_config();
|
||||
let pw = PartialWitness::new();
|
||||
let mut builder = CircuitBuilder::<F, 4>::new(config);
|
||||
|
||||
let x = builder.constant_biguint(&x_value);
|
||||
let y = builder.constant_biguint(&y_value);
|
||||
let cmp = builder.cmp_biguint(&x, &y);
|
||||
let expected_cmp = builder.constant_bool(x_value <= y_value);
|
||||
|
||||
builder.connect(cmp.target, expected_cmp.target);
|
||||
|
||||
let data = builder.build();
|
||||
let proof = data.prove(pw).unwrap();
|
||||
verify(proof, &data.verifier_only, &data.common)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_biguint_div_rem() -> Result<()> {
|
||||
let mut rng = rand::thread_rng();
|
||||
|
||||
let mut x_value = BigUint::from_u128(rng.gen()).unwrap();
|
||||
let mut y_value = BigUint::from_u128(rng.gen()).unwrap();
|
||||
if y_value > x_value {
|
||||
(x_value, y_value) = (y_value, x_value);
|
||||
}
|
||||
let (expected_div_value, expected_rem_value) = x_value.div_rem(&y_value);
|
||||
|
||||
type F = GoldilocksField;
|
||||
let config = CircuitConfig::standard_recursion_config();
|
||||
let pw = PartialWitness::new();
|
||||
let mut builder = CircuitBuilder::<F, 4>::new(config);
|
||||
|
||||
let x = builder.constant_biguint(&x_value);
|
||||
let y = builder.constant_biguint(&y_value);
|
||||
let (div, rem) = builder.div_rem_biguint(&x, &y);
|
||||
|
||||
let expected_div = builder.constant_biguint(&expected_div_value);
|
||||
let expected_rem = builder.constant_biguint(&expected_rem_value);
|
||||
|
||||
builder.connect_biguint(&div, &expected_div);
|
||||
builder.connect_biguint(&rem, &expected_rem);
|
||||
|
||||
let data = builder.build();
|
||||
let proof = data.prove(pw).unwrap();
|
||||
verify(proof, &data.verifier_only, &data.common)
|
||||
}
|
||||
}
|
||||
368
src/gadgets/curve.rs
Normal file
368
src/gadgets/curve.rs
Normal file
@ -0,0 +1,368 @@
|
||||
use crate::curve::curve_types::{AffinePoint, Curve, CurveScalar};
|
||||
use crate::field::extension_field::Extendable;
|
||||
use crate::field::field_types::{Field, RichField};
|
||||
use crate::gadgets::nonnative::NonNativeTarget;
|
||||
use crate::plonk::circuit_builder::CircuitBuilder;
|
||||
|
||||
/// A Target representing an affine point on the curve `C`. We use incomplete arithmetic for efficiency,
|
||||
/// so we assume these points are not zero.
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct AffinePointTarget<C: Curve> {
|
||||
pub x: NonNativeTarget<C::BaseField>,
|
||||
pub y: NonNativeTarget<C::BaseField>,
|
||||
}
|
||||
|
||||
impl<C: Curve> AffinePointTarget<C> {
|
||||
pub fn to_vec(&self) -> Vec<NonNativeTarget<C::BaseField>> {
|
||||
vec![self.x.clone(), self.y.clone()]
|
||||
}
|
||||
}
|
||||
|
||||
impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
|
||||
pub fn constant_affine_point<C: Curve>(
|
||||
&mut self,
|
||||
point: AffinePoint<C>,
|
||||
) -> AffinePointTarget<C> {
|
||||
debug_assert!(!point.zero);
|
||||
AffinePointTarget {
|
||||
x: self.constant_nonnative(point.x),
|
||||
y: self.constant_nonnative(point.y),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn connect_affine_point<C: Curve>(
|
||||
&mut self,
|
||||
lhs: &AffinePointTarget<C>,
|
||||
rhs: &AffinePointTarget<C>,
|
||||
) {
|
||||
self.connect_nonnative(&lhs.x, &rhs.x);
|
||||
self.connect_nonnative(&lhs.y, &rhs.y);
|
||||
}
|
||||
|
||||
pub fn add_virtual_affine_point_target<C: Curve>(&mut self) -> AffinePointTarget<C> {
|
||||
let x = self.add_virtual_nonnative_target();
|
||||
let y = self.add_virtual_nonnative_target();
|
||||
|
||||
AffinePointTarget { x, y }
|
||||
}
|
||||
|
||||
pub fn curve_assert_valid<C: Curve>(&mut self, p: &AffinePointTarget<C>) {
|
||||
let a = self.constant_nonnative(C::A);
|
||||
let b = self.constant_nonnative(C::B);
|
||||
|
||||
let y_squared = self.mul_nonnative(&p.y, &p.y);
|
||||
let x_squared = self.mul_nonnative(&p.x, &p.x);
|
||||
let x_cubed = self.mul_nonnative(&x_squared, &p.x);
|
||||
let a_x = self.mul_nonnative(&a, &p.x);
|
||||
let a_x_plus_b = self.add_nonnative(&a_x, &b);
|
||||
let rhs = self.add_nonnative(&x_cubed, &a_x_plus_b);
|
||||
|
||||
self.connect_nonnative(&y_squared, &rhs);
|
||||
}
|
||||
|
||||
pub fn curve_neg<C: Curve>(&mut self, p: &AffinePointTarget<C>) -> AffinePointTarget<C> {
|
||||
let neg_y = self.neg_nonnative(&p.y);
|
||||
AffinePointTarget {
|
||||
x: p.x.clone(),
|
||||
y: neg_y,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn curve_double<C: Curve>(&mut self, p: &AffinePointTarget<C>) -> AffinePointTarget<C> {
|
||||
let AffinePointTarget { x, y } = p;
|
||||
let double_y = self.add_nonnative(y, y);
|
||||
let inv_double_y = self.inv_nonnative(&double_y);
|
||||
let x_squared = self.mul_nonnative(x, x);
|
||||
let double_x_squared = self.add_nonnative(&x_squared, &x_squared);
|
||||
let triple_x_squared = self.add_nonnative(&double_x_squared, &x_squared);
|
||||
|
||||
let a = self.constant_nonnative(C::A);
|
||||
let triple_xx_a = self.add_nonnative(&triple_x_squared, &a);
|
||||
let lambda = self.mul_nonnative(&triple_xx_a, &inv_double_y);
|
||||
let lambda_squared = self.mul_nonnative(&lambda, &lambda);
|
||||
let x_double = self.add_nonnative(x, x);
|
||||
|
||||
let x3 = self.sub_nonnative(&lambda_squared, &x_double);
|
||||
|
||||
let x_diff = self.sub_nonnative(x, &x3);
|
||||
let lambda_x_diff = self.mul_nonnative(&lambda, &x_diff);
|
||||
|
||||
let y3 = self.sub_nonnative(&lambda_x_diff, y);
|
||||
|
||||
AffinePointTarget { x: x3, y: y3 }
|
||||
}
|
||||
|
||||
// Add two points, which are assumed to be non-equal.
|
||||
pub fn curve_add<C: Curve>(
|
||||
&mut self,
|
||||
p1: &AffinePointTarget<C>,
|
||||
p2: &AffinePointTarget<C>,
|
||||
) -> AffinePointTarget<C> {
|
||||
let AffinePointTarget { x: x1, y: y1 } = p1;
|
||||
let AffinePointTarget { x: x2, y: y2 } = p2;
|
||||
|
||||
let u = self.sub_nonnative(y2, y1);
|
||||
let uu = self.mul_nonnative(&u, &u);
|
||||
let v = self.sub_nonnative(x2, x1);
|
||||
let vv = self.mul_nonnative(&v, &v);
|
||||
let vvv = self.mul_nonnative(&v, &vv);
|
||||
let r = self.mul_nonnative(&vv, x1);
|
||||
let diff = self.sub_nonnative(&uu, &vvv);
|
||||
let r2 = self.add_nonnative(&r, &r);
|
||||
let a = self.sub_nonnative(&diff, &r2);
|
||||
let x3 = self.mul_nonnative(&v, &a);
|
||||
|
||||
let r_a = self.sub_nonnative(&r, &a);
|
||||
let y3_first = self.mul_nonnative(&u, &r_a);
|
||||
let y3_second = self.mul_nonnative(&vvv, y1);
|
||||
let y3 = self.sub_nonnative(&y3_first, &y3_second);
|
||||
|
||||
let z3_inv = self.inv_nonnative(&vvv);
|
||||
let x3_norm = self.mul_nonnative(&x3, &z3_inv);
|
||||
let y3_norm = self.mul_nonnative(&y3, &z3_inv);
|
||||
|
||||
AffinePointTarget {
|
||||
x: x3_norm,
|
||||
y: y3_norm,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn curve_scalar_mul<C: Curve>(
|
||||
&mut self,
|
||||
p: &AffinePointTarget<C>,
|
||||
n: &NonNativeTarget<C::ScalarField>,
|
||||
) -> AffinePointTarget<C> {
|
||||
let one = self.constant_nonnative(C::BaseField::ONE);
|
||||
|
||||
let bits = self.split_nonnative_to_bits(n);
|
||||
let bits_as_base: Vec<NonNativeTarget<C::BaseField>> =
|
||||
bits.iter().map(|b| self.bool_to_nonnative(b)).collect();
|
||||
|
||||
let rando = (CurveScalar(C::ScalarField::rand()) * C::GENERATOR_PROJECTIVE).to_affine();
|
||||
let randot = self.constant_affine_point(rando);
|
||||
// Result starts at `rando`, which is later subtracted, because we don't support arithmetic with the zero point.
|
||||
let mut result = self.add_virtual_affine_point_target();
|
||||
self.connect_affine_point(&randot, &result);
|
||||
|
||||
let mut two_i_times_p = self.add_virtual_affine_point_target();
|
||||
self.connect_affine_point(p, &two_i_times_p);
|
||||
|
||||
for bit in bits_as_base.iter() {
|
||||
let not_bit = self.sub_nonnative(&one, bit);
|
||||
|
||||
let result_plus_2_i_p = self.curve_add(&result, &two_i_times_p);
|
||||
|
||||
let new_x_if_bit = self.mul_nonnative(bit, &result_plus_2_i_p.x);
|
||||
let new_x_if_not_bit = self.mul_nonnative(¬_bit, &result.x);
|
||||
let new_y_if_bit = self.mul_nonnative(bit, &result_plus_2_i_p.y);
|
||||
let new_y_if_not_bit = self.mul_nonnative(¬_bit, &result.y);
|
||||
|
||||
let new_x = self.add_nonnative(&new_x_if_bit, &new_x_if_not_bit);
|
||||
let new_y = self.add_nonnative(&new_y_if_bit, &new_y_if_not_bit);
|
||||
|
||||
result = AffinePointTarget { x: new_x, y: new_y };
|
||||
|
||||
two_i_times_p = self.curve_double(&two_i_times_p);
|
||||
}
|
||||
|
||||
// Subtract off result's intial value of `rando`.
|
||||
let neg_r = self.curve_neg(&randot);
|
||||
result = self.curve_add(&result, &neg_r);
|
||||
|
||||
result
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use anyhow::Result;
|
||||
|
||||
use crate::curve::curve_types::{AffinePoint, Curve, CurveScalar};
|
||||
use crate::curve::secp256k1::Secp256K1;
|
||||
use crate::field::field_types::Field;
|
||||
use crate::field::goldilocks_field::GoldilocksField;
|
||||
use crate::field::secp256k1_base::Secp256K1Base;
|
||||
use crate::field::secp256k1_scalar::Secp256K1Scalar;
|
||||
use crate::iop::witness::PartialWitness;
|
||||
use crate::plonk::circuit_builder::CircuitBuilder;
|
||||
use crate::plonk::circuit_data::CircuitConfig;
|
||||
use crate::plonk::verifier::verify;
|
||||
|
||||
#[test]
|
||||
fn test_curve_point_is_valid() -> Result<()> {
|
||||
type F = GoldilocksField;
|
||||
const D: usize = 4;
|
||||
|
||||
let config = CircuitConfig::standard_recursion_config();
|
||||
|
||||
let pw = PartialWitness::new();
|
||||
let mut builder = CircuitBuilder::<F, D>::new(config);
|
||||
|
||||
let g = Secp256K1::GENERATOR_AFFINE;
|
||||
let g_target = builder.constant_affine_point(g);
|
||||
let neg_g_target = builder.curve_neg(&g_target);
|
||||
|
||||
builder.curve_assert_valid(&g_target);
|
||||
builder.curve_assert_valid(&neg_g_target);
|
||||
|
||||
let data = builder.build();
|
||||
let proof = data.prove(pw).unwrap();
|
||||
|
||||
verify(proof, &data.verifier_only, &data.common)
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic]
|
||||
fn test_curve_point_is_not_valid() {
|
||||
type F = GoldilocksField;
|
||||
const D: usize = 4;
|
||||
|
||||
let config = CircuitConfig::standard_recursion_config();
|
||||
|
||||
let pw = PartialWitness::new();
|
||||
let mut builder = CircuitBuilder::<F, D>::new(config);
|
||||
|
||||
let g = Secp256K1::GENERATOR_AFFINE;
|
||||
let not_g = AffinePoint::<Secp256K1> {
|
||||
x: g.x,
|
||||
y: g.y + Secp256K1Base::ONE,
|
||||
zero: g.zero,
|
||||
};
|
||||
let not_g_target = builder.constant_affine_point(not_g);
|
||||
|
||||
builder.curve_assert_valid(¬_g_target);
|
||||
|
||||
let data = builder.build();
|
||||
let proof = data.prove(pw).unwrap();
|
||||
|
||||
verify(proof, &data.verifier_only, &data.common).unwrap();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_curve_double() -> Result<()> {
|
||||
type F = GoldilocksField;
|
||||
const D: usize = 4;
|
||||
|
||||
let config = CircuitConfig::standard_recursion_config();
|
||||
|
||||
let pw = PartialWitness::new();
|
||||
let mut builder = CircuitBuilder::<F, D>::new(config);
|
||||
|
||||
let g = Secp256K1::GENERATOR_AFFINE;
|
||||
let g_target = builder.constant_affine_point(g);
|
||||
let neg_g_target = builder.curve_neg(&g_target);
|
||||
|
||||
let double_g = g.double();
|
||||
let double_g_expected = builder.constant_affine_point(double_g);
|
||||
builder.curve_assert_valid(&double_g_expected);
|
||||
|
||||
let double_neg_g = (-g).double();
|
||||
let double_neg_g_expected = builder.constant_affine_point(double_neg_g);
|
||||
builder.curve_assert_valid(&double_neg_g_expected);
|
||||
|
||||
let double_g_actual = builder.curve_double(&g_target);
|
||||
let double_neg_g_actual = builder.curve_double(&neg_g_target);
|
||||
builder.curve_assert_valid(&double_g_actual);
|
||||
builder.curve_assert_valid(&double_neg_g_actual);
|
||||
|
||||
builder.connect_affine_point(&double_g_expected, &double_g_actual);
|
||||
builder.connect_affine_point(&double_neg_g_expected, &double_neg_g_actual);
|
||||
|
||||
let data = builder.build();
|
||||
let proof = data.prove(pw).unwrap();
|
||||
|
||||
verify(proof, &data.verifier_only, &data.common)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_curve_add() -> Result<()> {
|
||||
type F = GoldilocksField;
|
||||
const D: usize = 4;
|
||||
|
||||
let config = CircuitConfig::standard_recursion_config();
|
||||
|
||||
let pw = PartialWitness::new();
|
||||
let mut builder = CircuitBuilder::<F, D>::new(config);
|
||||
|
||||
let g = Secp256K1::GENERATOR_AFFINE;
|
||||
let double_g = g.double();
|
||||
let g_plus_2g = (g + double_g).to_affine();
|
||||
let g_plus_2g_expected = builder.constant_affine_point(g_plus_2g);
|
||||
builder.curve_assert_valid(&g_plus_2g_expected);
|
||||
|
||||
let g_target = builder.constant_affine_point(g);
|
||||
let double_g_target = builder.curve_double(&g_target);
|
||||
let g_plus_2g_actual = builder.curve_add(&g_target, &double_g_target);
|
||||
builder.curve_assert_valid(&g_plus_2g_actual);
|
||||
|
||||
builder.connect_affine_point(&g_plus_2g_expected, &g_plus_2g_actual);
|
||||
|
||||
let data = builder.build();
|
||||
let proof = data.prove(pw).unwrap();
|
||||
|
||||
verify(proof, &data.verifier_only, &data.common)
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[ignore]
|
||||
fn test_curve_mul() -> Result<()> {
|
||||
type F = GoldilocksField;
|
||||
const D: usize = 4;
|
||||
|
||||
let config = CircuitConfig {
|
||||
num_routed_wires: 33,
|
||||
..CircuitConfig::standard_recursion_config()
|
||||
};
|
||||
|
||||
let pw = PartialWitness::new();
|
||||
let mut builder = CircuitBuilder::<F, D>::new(config);
|
||||
|
||||
let g = Secp256K1::GENERATOR_AFFINE;
|
||||
let five = Secp256K1Scalar::from_canonical_usize(5);
|
||||
let five_scalar = CurveScalar::<Secp256K1>(five);
|
||||
let five_g = (five_scalar * g.to_projective()).to_affine();
|
||||
let five_g_expected = builder.constant_affine_point(five_g);
|
||||
builder.curve_assert_valid(&five_g_expected);
|
||||
|
||||
let g_target = builder.constant_affine_point(g);
|
||||
let five_target = builder.constant_nonnative(five);
|
||||
let five_g_actual = builder.curve_scalar_mul(&g_target, &five_target);
|
||||
builder.curve_assert_valid(&five_g_actual);
|
||||
|
||||
builder.connect_affine_point(&five_g_expected, &five_g_actual);
|
||||
|
||||
let data = builder.build();
|
||||
let proof = data.prove(pw).unwrap();
|
||||
|
||||
verify(proof, &data.verifier_only, &data.common)
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[ignore]
|
||||
fn test_curve_random() -> Result<()> {
|
||||
type F = GoldilocksField;
|
||||
const D: usize = 4;
|
||||
|
||||
let config = CircuitConfig {
|
||||
num_routed_wires: 33,
|
||||
..CircuitConfig::standard_recursion_config()
|
||||
};
|
||||
|
||||
let pw = PartialWitness::new();
|
||||
let mut builder = CircuitBuilder::<F, D>::new(config);
|
||||
|
||||
let rando =
|
||||
(CurveScalar(Secp256K1Scalar::rand()) * Secp256K1::GENERATOR_PROJECTIVE).to_affine();
|
||||
let randot = builder.constant_affine_point(rando);
|
||||
|
||||
let two_target = builder.constant_nonnative(Secp256K1Scalar::TWO);
|
||||
let randot_doubled = builder.curve_double(&randot);
|
||||
let randot_times_two = builder.curve_scalar_mul(&randot, &two_target);
|
||||
builder.connect_affine_point(&randot_doubled, &randot_times_two);
|
||||
|
||||
let data = builder.build();
|
||||
let proof = data.prove(pw).unwrap();
|
||||
|
||||
verify(proof, &data.verifier_only, &data.common)
|
||||
}
|
||||
}
|
||||
@ -1,22 +1,94 @@
|
||||
use std::ops::Range;
|
||||
|
||||
use crate::field::extension_field::target::ExtensionTarget;
|
||||
use crate::field::extension_field::Extendable;
|
||||
use crate::field::field_types::RichField;
|
||||
use crate::gates::gate::Gate;
|
||||
use crate::gates::interpolation::InterpolationGate;
|
||||
use crate::iop::target::Target;
|
||||
use crate::plonk::circuit_builder::CircuitBuilder;
|
||||
|
||||
impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
|
||||
/// Trait for gates which interpolate a polynomial, whose points are a (base field) coset of the multiplicative subgroup
|
||||
/// with the given size, and whose values are extension field elements, given by input wires.
|
||||
/// Outputs the evaluation of the interpolant at a given (extension field) evaluation point.
|
||||
pub(crate) trait InterpolationGate<F: Extendable<D>, const D: usize>:
|
||||
Gate<F, D> + Copy
|
||||
{
|
||||
fn new(subgroup_bits: usize) -> Self;
|
||||
|
||||
fn num_points(&self) -> usize;
|
||||
|
||||
/// Wire index of the coset shift.
|
||||
fn wire_shift(&self) -> usize {
|
||||
0
|
||||
}
|
||||
|
||||
fn start_values(&self) -> usize {
|
||||
1
|
||||
}
|
||||
|
||||
/// Wire indices of the `i`th interpolant value.
|
||||
fn wires_value(&self, i: usize) -> Range<usize> {
|
||||
debug_assert!(i < self.num_points());
|
||||
let start = self.start_values() + i * D;
|
||||
start..start + D
|
||||
}
|
||||
|
||||
fn start_evaluation_point(&self) -> usize {
|
||||
self.start_values() + self.num_points() * D
|
||||
}
|
||||
|
||||
/// Wire indices of the point to evaluate the interpolant at.
|
||||
fn wires_evaluation_point(&self) -> Range<usize> {
|
||||
let start = self.start_evaluation_point();
|
||||
start..start + D
|
||||
}
|
||||
|
||||
fn start_evaluation_value(&self) -> usize {
|
||||
self.start_evaluation_point() + D
|
||||
}
|
||||
|
||||
/// Wire indices of the interpolated value.
|
||||
fn wires_evaluation_value(&self) -> Range<usize> {
|
||||
let start = self.start_evaluation_value();
|
||||
start..start + D
|
||||
}
|
||||
|
||||
fn start_coeffs(&self) -> usize {
|
||||
self.start_evaluation_value() + D
|
||||
}
|
||||
|
||||
/// The number of routed wires required in the typical usage of this gate, where the points to
|
||||
/// interpolate, the evaluation point, and the corresponding value are all routed.
|
||||
fn num_routed_wires(&self) -> usize {
|
||||
self.start_coeffs()
|
||||
}
|
||||
|
||||
/// Wire indices of the interpolant's `i`th coefficient.
|
||||
fn wires_coeff(&self, i: usize) -> Range<usize> {
|
||||
debug_assert!(i < self.num_points());
|
||||
let start = self.start_coeffs() + i * D;
|
||||
start..start + D
|
||||
}
|
||||
|
||||
fn end_coeffs(&self) -> usize {
|
||||
self.start_coeffs() + D * self.num_points()
|
||||
}
|
||||
}
|
||||
|
||||
impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
|
||||
/// Interpolates a polynomial, whose points are a coset of the multiplicative subgroup with the
|
||||
/// given size, and whose values are given. Returns the evaluation of the interpolant at
|
||||
/// `evaluation_point`.
|
||||
pub fn interpolate_coset(
|
||||
pub(crate) fn interpolate_coset<G: InterpolationGate<F, D>>(
|
||||
&mut self,
|
||||
subgroup_bits: usize,
|
||||
coset_shift: Target,
|
||||
values: &[ExtensionTarget<D>],
|
||||
evaluation_point: ExtensionTarget<D>,
|
||||
) -> ExtensionTarget<D> {
|
||||
let gate = InterpolationGate::new(subgroup_bits);
|
||||
let gate_index = self.add_gate(gate.clone(), vec![]);
|
||||
let gate = G::new(subgroup_bits);
|
||||
let gate_index = self.add_gate(gate, vec![]);
|
||||
self.connect(coset_shift, Target::wire(gate_index, gate.wire_shift()));
|
||||
for (i, &v) in values.iter().enumerate() {
|
||||
self.connect_extension(
|
||||
@ -37,6 +109,7 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
|
||||
mod tests {
|
||||
use anyhow::Result;
|
||||
|
||||
use crate::field::extension_field::quadratic::QuadraticExtension;
|
||||
use crate::field::extension_field::FieldExtension;
|
||||
use crate::field::field_types::Field;
|
||||
use crate::field::interpolation::interpolant;
|
||||
@ -83,9 +156,21 @@ mod tests {
|
||||
|
||||
let zt = builder.constant_extension(z);
|
||||
|
||||
let eval = builder.interpolate_coset(subgroup_bits, coset_shift_target, &value_targets, zt);
|
||||
let eval_hd = builder.interpolate_coset::<HighDegreeInterpolationGate<F, D>>(
|
||||
subgroup_bits,
|
||||
coset_shift_target,
|
||||
&value_targets,
|
||||
zt,
|
||||
);
|
||||
let eval_ld = builder.interpolate_coset::<LowDegreeInterpolationGate<F, D>>(
|
||||
subgroup_bits,
|
||||
coset_shift_target,
|
||||
&value_targets,
|
||||
zt,
|
||||
);
|
||||
let true_eval_target = builder.constant_extension(true_eval);
|
||||
builder.connect_extension(eval, true_eval_target);
|
||||
builder.connect_extension(eval_hd, true_eval_target);
|
||||
builder.connect_extension(eval_ld, true_eval_target);
|
||||
|
||||
let data = builder.build::<C>();
|
||||
let proof = data.prove(pw)?;
|
||||
|
||||
@ -1,8 +1,13 @@
|
||||
pub mod arithmetic;
|
||||
pub mod arithmetic_extension;
|
||||
pub mod arithmetic_u32;
|
||||
pub mod biguint;
|
||||
pub mod curve;
|
||||
pub mod hash;
|
||||
pub mod insert;
|
||||
pub mod interpolation;
|
||||
pub mod multiple_comparison;
|
||||
pub mod nonnative;
|
||||
pub mod permutation;
|
||||
pub mod polynomial;
|
||||
pub mod random_access;
|
||||
|
||||
138
src/gadgets/multiple_comparison.rs
Normal file
138
src/gadgets/multiple_comparison.rs
Normal file
@ -0,0 +1,138 @@
|
||||
use super::arithmetic_u32::U32Target;
|
||||
use crate::field::extension_field::Extendable;
|
||||
use crate::field::field_types::RichField;
|
||||
use crate::gates::comparison::ComparisonGate;
|
||||
use crate::iop::target::{BoolTarget, Target};
|
||||
use crate::plonk::circuit_builder::CircuitBuilder;
|
||||
use crate::util::ceil_div_usize;
|
||||
|
||||
impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
|
||||
/// Returns true if a is less than or equal to b, considered as base-`2^num_bits` limbs of a large value.
|
||||
/// This range-checks its inputs.
|
||||
pub fn list_le(&mut self, a: Vec<Target>, b: Vec<Target>, num_bits: usize) -> BoolTarget {
|
||||
assert_eq!(
|
||||
a.len(),
|
||||
b.len(),
|
||||
"Comparison must be between same number of inputs and outputs"
|
||||
);
|
||||
let n = a.len();
|
||||
|
||||
let chunk_bits = 2;
|
||||
let num_chunks = ceil_div_usize(num_bits, chunk_bits);
|
||||
|
||||
let one = self.one();
|
||||
let mut result = one;
|
||||
for i in 0..n {
|
||||
let a_le_b_gate = ComparisonGate::new(num_bits, num_chunks);
|
||||
let a_le_b_gate_index = self.add_gate(a_le_b_gate.clone(), vec![]);
|
||||
self.connect(
|
||||
Target::wire(a_le_b_gate_index, a_le_b_gate.wire_first_input()),
|
||||
a[i],
|
||||
);
|
||||
self.connect(
|
||||
Target::wire(a_le_b_gate_index, a_le_b_gate.wire_second_input()),
|
||||
b[i],
|
||||
);
|
||||
let a_le_b_result = Target::wire(a_le_b_gate_index, a_le_b_gate.wire_result_bool());
|
||||
|
||||
let b_le_a_gate = ComparisonGate::new(num_bits, num_chunks);
|
||||
let b_le_a_gate_index = self.add_gate(b_le_a_gate.clone(), vec![]);
|
||||
self.connect(
|
||||
Target::wire(b_le_a_gate_index, b_le_a_gate.wire_first_input()),
|
||||
b[i],
|
||||
);
|
||||
self.connect(
|
||||
Target::wire(b_le_a_gate_index, b_le_a_gate.wire_second_input()),
|
||||
a[i],
|
||||
);
|
||||
let b_le_a_result = Target::wire(b_le_a_gate_index, b_le_a_gate.wire_result_bool());
|
||||
|
||||
let these_limbs_equal = self.mul(a_le_b_result, b_le_a_result);
|
||||
let these_limbs_less_than = self.sub(one, b_le_a_result);
|
||||
result = self.mul_add(these_limbs_equal, result, these_limbs_less_than);
|
||||
}
|
||||
|
||||
// `result` being boolean is an invariant, maintained because its new value is always
|
||||
// `x * result + y`, where `x` and `y` are booleans that are not simultaneously true.
|
||||
BoolTarget::new_unsafe(result)
|
||||
}
|
||||
|
||||
/// Helper function for comparing, specifically, lists of `U32Target`s.
|
||||
pub fn list_le_u32(&mut self, a: Vec<U32Target>, b: Vec<U32Target>) -> BoolTarget {
|
||||
let a_targets = a.iter().map(|&t| t.0).collect();
|
||||
let b_targets = b.iter().map(|&t| t.0).collect();
|
||||
self.list_le(a_targets, b_targets, 32)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use anyhow::Result;
|
||||
use num::BigUint;
|
||||
use rand::Rng;
|
||||
|
||||
use crate::field::field_types::Field;
|
||||
use crate::field::goldilocks_field::GoldilocksField;
|
||||
use crate::iop::witness::PartialWitness;
|
||||
use crate::plonk::circuit_builder::CircuitBuilder;
|
||||
use crate::plonk::circuit_data::CircuitConfig;
|
||||
use crate::plonk::verifier::verify;
|
||||
|
||||
fn test_list_le(size: usize, num_bits: usize) -> Result<()> {
|
||||
type F = GoldilocksField;
|
||||
let config = CircuitConfig::standard_recursion_config();
|
||||
let pw = PartialWitness::new();
|
||||
let mut builder = CircuitBuilder::<F, 4>::new(config);
|
||||
|
||||
let mut rng = rand::thread_rng();
|
||||
|
||||
let lst1: Vec<u64> = (0..size)
|
||||
.map(|_| rng.gen_range(0..(1 << num_bits)))
|
||||
.collect();
|
||||
let lst2: Vec<u64> = (0..size)
|
||||
.map(|_| rng.gen_range(0..(1 << num_bits)))
|
||||
.collect();
|
||||
|
||||
let a_biguint = BigUint::from_slice(
|
||||
&lst1
|
||||
.iter()
|
||||
.flat_map(|&x| [x as u32, (x >> 32) as u32])
|
||||
.collect::<Vec<_>>(),
|
||||
);
|
||||
let b_biguint = BigUint::from_slice(
|
||||
&lst2
|
||||
.iter()
|
||||
.flat_map(|&x| [x as u32, (x >> 32) as u32])
|
||||
.collect::<Vec<_>>(),
|
||||
);
|
||||
|
||||
let a = lst1
|
||||
.iter()
|
||||
.map(|&x| builder.constant(F::from_canonical_u64(x)))
|
||||
.collect();
|
||||
let b = lst2
|
||||
.iter()
|
||||
.map(|&x| builder.constant(F::from_canonical_u64(x)))
|
||||
.collect();
|
||||
|
||||
let result = builder.list_le(a, b, num_bits);
|
||||
|
||||
let expected_result = builder.constant_bool(a_biguint <= b_biguint);
|
||||
builder.connect(result.target, expected_result.target);
|
||||
|
||||
let data = builder.build();
|
||||
let proof = data.prove(pw).unwrap();
|
||||
verify(proof, &data.verifier_only, &data.common)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_multiple_comparison() -> Result<()> {
|
||||
for size in [1, 3, 6] {
|
||||
for num_bits in [20, 32, 40, 44] {
|
||||
test_list_le(size, num_bits).unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
342
src/gadgets/nonnative.rs
Normal file
342
src/gadgets/nonnative.rs
Normal file
@ -0,0 +1,342 @@
|
||||
use std::marker::PhantomData;
|
||||
|
||||
use num::{BigUint, Zero};
|
||||
|
||||
use crate::field::field_types::RichField;
|
||||
use crate::field::{extension_field::Extendable, field_types::Field};
|
||||
use crate::gadgets::arithmetic_u32::U32Target;
|
||||
use crate::gadgets::biguint::BigUintTarget;
|
||||
use crate::iop::generator::{GeneratedValues, SimpleGenerator};
|
||||
use crate::iop::target::{BoolTarget, Target};
|
||||
use crate::iop::witness::{PartitionWitness, Witness};
|
||||
use crate::plonk::circuit_builder::CircuitBuilder;
|
||||
use crate::util::ceil_div_usize;
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct NonNativeTarget<FF: Field> {
|
||||
pub(crate) value: BigUintTarget,
|
||||
_phantom: PhantomData<FF>,
|
||||
}
|
||||
|
||||
impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
|
||||
fn num_nonnative_limbs<FF: Field>() -> usize {
|
||||
ceil_div_usize(FF::BITS, 32)
|
||||
}
|
||||
|
||||
pub fn biguint_to_nonnative<FF: Field>(&mut self, x: &BigUintTarget) -> NonNativeTarget<FF> {
|
||||
NonNativeTarget {
|
||||
value: x.clone(),
|
||||
_phantom: PhantomData,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn nonnative_to_biguint<FF: Field>(&mut self, x: &NonNativeTarget<FF>) -> BigUintTarget {
|
||||
x.value.clone()
|
||||
}
|
||||
|
||||
pub fn constant_nonnative<FF: Field>(&mut self, x: FF) -> NonNativeTarget<FF> {
|
||||
let x_biguint = self.constant_biguint(&x.to_biguint());
|
||||
self.biguint_to_nonnative(&x_biguint)
|
||||
}
|
||||
|
||||
// Assert that two NonNativeTarget's, both assumed to be in reduced form, are equal.
|
||||
pub fn connect_nonnative<FF: Field>(
|
||||
&mut self,
|
||||
lhs: &NonNativeTarget<FF>,
|
||||
rhs: &NonNativeTarget<FF>,
|
||||
) {
|
||||
self.connect_biguint(&lhs.value, &rhs.value);
|
||||
}
|
||||
|
||||
pub fn add_virtual_nonnative_target<FF: Field>(&mut self) -> NonNativeTarget<FF> {
|
||||
let num_limbs = Self::num_nonnative_limbs::<FF>();
|
||||
let value = self.add_virtual_biguint_target(num_limbs);
|
||||
|
||||
NonNativeTarget {
|
||||
value,
|
||||
_phantom: PhantomData,
|
||||
}
|
||||
}
|
||||
|
||||
// Add two `NonNativeTarget`s.
|
||||
pub fn add_nonnative<FF: Field>(
|
||||
&mut self,
|
||||
a: &NonNativeTarget<FF>,
|
||||
b: &NonNativeTarget<FF>,
|
||||
) -> NonNativeTarget<FF> {
|
||||
let result = self.add_biguint(&a.value, &b.value);
|
||||
|
||||
// TODO: reduce add result with only one conditional subtraction
|
||||
self.reduce(&result)
|
||||
}
|
||||
|
||||
// Subtract two `NonNativeTarget`s.
|
||||
pub fn sub_nonnative<FF: Field>(
|
||||
&mut self,
|
||||
a: &NonNativeTarget<FF>,
|
||||
b: &NonNativeTarget<FF>,
|
||||
) -> NonNativeTarget<FF> {
|
||||
let order = self.constant_biguint(&FF::order());
|
||||
let a_plus_order = self.add_biguint(&order, &a.value);
|
||||
let result = self.sub_biguint(&a_plus_order, &b.value);
|
||||
|
||||
// TODO: reduce sub result with only one conditional addition?
|
||||
self.reduce(&result)
|
||||
}
|
||||
|
||||
pub fn mul_nonnative<FF: Field>(
|
||||
&mut self,
|
||||
a: &NonNativeTarget<FF>,
|
||||
b: &NonNativeTarget<FF>,
|
||||
) -> NonNativeTarget<FF> {
|
||||
let result = self.mul_biguint(&a.value, &b.value);
|
||||
|
||||
self.reduce(&result)
|
||||
}
|
||||
|
||||
pub fn neg_nonnative<FF: Field>(&mut self, x: &NonNativeTarget<FF>) -> NonNativeTarget<FF> {
|
||||
let zero_target = self.constant_biguint(&BigUint::zero());
|
||||
let zero_ff = self.biguint_to_nonnative(&zero_target);
|
||||
|
||||
self.sub_nonnative(&zero_ff, x)
|
||||
}
|
||||
|
||||
pub fn inv_nonnative<FF: Field>(&mut self, x: &NonNativeTarget<FF>) -> NonNativeTarget<FF> {
|
||||
let num_limbs = x.value.num_limbs();
|
||||
let inv_biguint = self.add_virtual_biguint_target(num_limbs);
|
||||
let inv = NonNativeTarget::<FF> {
|
||||
value: inv_biguint,
|
||||
_phantom: PhantomData,
|
||||
};
|
||||
|
||||
self.add_simple_generator(NonNativeInverseGenerator::<F, D, FF> {
|
||||
x: x.clone(),
|
||||
inv: inv.clone(),
|
||||
_phantom: PhantomData,
|
||||
});
|
||||
|
||||
let product = self.mul_nonnative(x, &inv);
|
||||
let one = self.constant_nonnative(FF::ONE);
|
||||
self.connect_nonnative(&product, &one);
|
||||
|
||||
inv
|
||||
}
|
||||
|
||||
pub fn div_rem_nonnative<FF: Field>(
|
||||
&mut self,
|
||||
x: &NonNativeTarget<FF>,
|
||||
y: &NonNativeTarget<FF>,
|
||||
) -> (NonNativeTarget<FF>, NonNativeTarget<FF>) {
|
||||
let x_biguint = self.nonnative_to_biguint(x);
|
||||
let y_biguint = self.nonnative_to_biguint(y);
|
||||
|
||||
let (div_biguint, rem_biguint) = self.div_rem_biguint(&x_biguint, &y_biguint);
|
||||
let div = self.biguint_to_nonnative(&div_biguint);
|
||||
let rem = self.biguint_to_nonnative(&rem_biguint);
|
||||
(div, rem)
|
||||
}
|
||||
|
||||
/// Returns `x % |FF|` as a `NonNativeTarget`.
|
||||
fn reduce<FF: Field>(&mut self, x: &BigUintTarget) -> NonNativeTarget<FF> {
|
||||
let modulus = FF::order();
|
||||
let order_target = self.constant_biguint(&modulus);
|
||||
let value = self.rem_biguint(x, &order_target);
|
||||
|
||||
NonNativeTarget {
|
||||
value,
|
||||
_phantom: PhantomData,
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
fn reduce_nonnative<FF: Field>(&mut self, x: &NonNativeTarget<FF>) -> NonNativeTarget<FF> {
|
||||
let x_biguint = self.nonnative_to_biguint(x);
|
||||
self.reduce(&x_biguint)
|
||||
}
|
||||
|
||||
pub fn bool_to_nonnative<FF: Field>(&mut self, b: &BoolTarget) -> NonNativeTarget<FF> {
|
||||
let limbs = vec![U32Target(b.target)];
|
||||
let value = BigUintTarget { limbs };
|
||||
|
||||
NonNativeTarget {
|
||||
value,
|
||||
_phantom: PhantomData,
|
||||
}
|
||||
}
|
||||
|
||||
// Split a nonnative field element to bits.
|
||||
pub fn split_nonnative_to_bits<FF: Field>(
|
||||
&mut self,
|
||||
x: &NonNativeTarget<FF>,
|
||||
) -> Vec<BoolTarget> {
|
||||
let num_limbs = x.value.num_limbs();
|
||||
let mut result = Vec::with_capacity(num_limbs * 32);
|
||||
|
||||
for i in 0..num_limbs {
|
||||
let limb = x.value.get_limb(i);
|
||||
let bit_targets = self.split_le_base::<2>(limb.0, 32);
|
||||
let mut bits: Vec<_> = bit_targets
|
||||
.iter()
|
||||
.map(|&t| BoolTarget::new_unsafe(t))
|
||||
.collect();
|
||||
|
||||
result.append(&mut bits);
|
||||
}
|
||||
|
||||
result
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct NonNativeInverseGenerator<F: RichField + Extendable<D>, const D: usize, FF: Field> {
|
||||
x: NonNativeTarget<FF>,
|
||||
inv: NonNativeTarget<FF>,
|
||||
_phantom: PhantomData<F>,
|
||||
}
|
||||
|
||||
impl<F: RichField + Extendable<D>, const D: usize, FF: Field> SimpleGenerator<F>
|
||||
for NonNativeInverseGenerator<F, D, FF>
|
||||
{
|
||||
fn dependencies(&self) -> Vec<Target> {
|
||||
self.x.value.limbs.iter().map(|&l| l.0).collect()
|
||||
}
|
||||
|
||||
fn run_once(&self, witness: &PartitionWitness<F>, out_buffer: &mut GeneratedValues<F>) {
|
||||
let x = witness.get_nonnative_target(self.x.clone());
|
||||
let inv = x.inverse();
|
||||
|
||||
out_buffer.set_nonnative_target(self.inv.clone(), inv);
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use anyhow::Result;
|
||||
|
||||
use crate::field::field_types::Field;
|
||||
use crate::field::goldilocks_field::GoldilocksField;
|
||||
use crate::field::secp256k1_base::Secp256K1Base;
|
||||
use crate::iop::witness::PartialWitness;
|
||||
use crate::plonk::circuit_builder::CircuitBuilder;
|
||||
use crate::plonk::circuit_data::CircuitConfig;
|
||||
use crate::plonk::verifier::verify;
|
||||
|
||||
#[test]
|
||||
fn test_nonnative_add() -> Result<()> {
|
||||
type FF = Secp256K1Base;
|
||||
let x_ff = FF::rand();
|
||||
let y_ff = FF::rand();
|
||||
let sum_ff = x_ff + y_ff;
|
||||
|
||||
type F = GoldilocksField;
|
||||
let config = CircuitConfig::standard_recursion_config();
|
||||
let pw = PartialWitness::new();
|
||||
let mut builder = CircuitBuilder::<F, 4>::new(config);
|
||||
|
||||
let x = builder.constant_nonnative(x_ff);
|
||||
let y = builder.constant_nonnative(y_ff);
|
||||
let sum = builder.add_nonnative(&x, &y);
|
||||
|
||||
let sum_expected = builder.constant_nonnative(sum_ff);
|
||||
builder.connect_nonnative(&sum, &sum_expected);
|
||||
|
||||
let data = builder.build();
|
||||
let proof = data.prove(pw).unwrap();
|
||||
verify(proof, &data.verifier_only, &data.common)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_nonnative_sub() -> Result<()> {
|
||||
type FF = Secp256K1Base;
|
||||
let x_ff = FF::rand();
|
||||
let mut y_ff = FF::rand();
|
||||
while y_ff.to_biguint() > x_ff.to_biguint() {
|
||||
y_ff = FF::rand();
|
||||
}
|
||||
let diff_ff = x_ff - y_ff;
|
||||
|
||||
type F = GoldilocksField;
|
||||
let config = CircuitConfig::standard_recursion_config();
|
||||
let pw = PartialWitness::new();
|
||||
let mut builder = CircuitBuilder::<F, 4>::new(config);
|
||||
|
||||
let x = builder.constant_nonnative(x_ff);
|
||||
let y = builder.constant_nonnative(y_ff);
|
||||
let diff = builder.sub_nonnative(&x, &y);
|
||||
|
||||
let diff_expected = builder.constant_nonnative(diff_ff);
|
||||
builder.connect_nonnative(&diff, &diff_expected);
|
||||
|
||||
let data = builder.build();
|
||||
let proof = data.prove(pw).unwrap();
|
||||
verify(proof, &data.verifier_only, &data.common)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_nonnative_mul() -> Result<()> {
|
||||
type FF = Secp256K1Base;
|
||||
let x_ff = FF::rand();
|
||||
let y_ff = FF::rand();
|
||||
let product_ff = x_ff * y_ff;
|
||||
|
||||
type F = GoldilocksField;
|
||||
let config = CircuitConfig::standard_recursion_config();
|
||||
let pw = PartialWitness::new();
|
||||
let mut builder = CircuitBuilder::<F, 4>::new(config);
|
||||
|
||||
let x = builder.constant_nonnative(x_ff);
|
||||
let y = builder.constant_nonnative(y_ff);
|
||||
let product = builder.mul_nonnative(&x, &y);
|
||||
|
||||
let product_expected = builder.constant_nonnative(product_ff);
|
||||
builder.connect_nonnative(&product, &product_expected);
|
||||
|
||||
let data = builder.build();
|
||||
let proof = data.prove(pw).unwrap();
|
||||
verify(proof, &data.verifier_only, &data.common)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_nonnative_neg() -> Result<()> {
|
||||
type FF = Secp256K1Base;
|
||||
let x_ff = FF::rand();
|
||||
let neg_x_ff = -x_ff;
|
||||
|
||||
type F = GoldilocksField;
|
||||
let config = CircuitConfig::standard_recursion_config();
|
||||
let pw = PartialWitness::new();
|
||||
let mut builder = CircuitBuilder::<F, 4>::new(config);
|
||||
|
||||
let x = builder.constant_nonnative(x_ff);
|
||||
let neg_x = builder.neg_nonnative(&x);
|
||||
|
||||
let neg_x_expected = builder.constant_nonnative(neg_x_ff);
|
||||
builder.connect_nonnative(&neg_x, &neg_x_expected);
|
||||
|
||||
let data = builder.build();
|
||||
let proof = data.prove(pw).unwrap();
|
||||
verify(proof, &data.verifier_only, &data.common)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_nonnative_inv() -> Result<()> {
|
||||
type FF = Secp256K1Base;
|
||||
let x_ff = FF::rand();
|
||||
let inv_x_ff = x_ff.inverse();
|
||||
|
||||
type F = GoldilocksField;
|
||||
let config = CircuitConfig::standard_recursion_config();
|
||||
let pw = PartialWitness::new();
|
||||
let mut builder = CircuitBuilder::<F, 4>::new(config);
|
||||
|
||||
let x = builder.constant_nonnative(x_ff);
|
||||
let inv_x = builder.inv_nonnative(&x);
|
||||
|
||||
let inv_x_expected = builder.constant_nonnative(inv_x_ff);
|
||||
builder.connect_nonnative(&inv_x, &inv_x_expected);
|
||||
|
||||
let data = builder.build();
|
||||
let proof = data.prove(pw).unwrap();
|
||||
verify(proof, &data.verifier_only, &data.common)
|
||||
}
|
||||
}
|
||||
@ -2,7 +2,6 @@ use std::collections::BTreeMap;
|
||||
use std::marker::PhantomData;
|
||||
|
||||
use crate::field::{extension_field::Extendable, field_types::Field};
|
||||
use crate::gates::switch::SwitchGate;
|
||||
use crate::iop::generator::{GeneratedValues, SimpleGenerator};
|
||||
use crate::iop::target::Target;
|
||||
use crate::iop::witness::{PartitionWitness, Witness};
|
||||
@ -34,7 +33,6 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
|
||||
self.assert_permutation_2x2(a[0].clone(), a[1].clone(), b[0].clone(), b[1].clone())
|
||||
}
|
||||
// For larger lists, we recursively use two smaller permutation networks.
|
||||
//_ => self.assert_permutation_recursive(a, b)
|
||||
_ => self.assert_permutation_recursive(a, b),
|
||||
}
|
||||
}
|
||||
@ -72,22 +70,7 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
|
||||
|
||||
let chunk_size = a1.len();
|
||||
|
||||
if self.current_switch_gates.len() < chunk_size {
|
||||
self.current_switch_gates
|
||||
.extend(vec![None; chunk_size - self.current_switch_gates.len()]);
|
||||
}
|
||||
|
||||
let (gate, gate_index, mut next_copy) =
|
||||
match self.current_switch_gates[chunk_size - 1].clone() {
|
||||
None => {
|
||||
let gate = SwitchGate::<F, D>::new_from_config(&self.config, chunk_size);
|
||||
let gate_index = self.add_gate(gate.clone(), vec![]);
|
||||
(gate, gate_index, 0)
|
||||
}
|
||||
Some((gate, idx, next_copy)) => (gate, idx, next_copy),
|
||||
};
|
||||
|
||||
let num_copies = gate.num_copies;
|
||||
let (gate, gate_index, next_copy) = self.find_switch_gate(chunk_size);
|
||||
|
||||
let mut c = Vec::new();
|
||||
let mut d = Vec::new();
|
||||
@ -112,13 +95,6 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
|
||||
|
||||
let switch = Target::wire(gate_index, gate.wire_switch_bool(next_copy));
|
||||
|
||||
next_copy += 1;
|
||||
if next_copy == num_copies {
|
||||
self.current_switch_gates[chunk_size - 1] = None;
|
||||
} else {
|
||||
self.current_switch_gates[chunk_size - 1] = Some((gate, gate_index, next_copy));
|
||||
}
|
||||
|
||||
(switch, c, d)
|
||||
}
|
||||
|
||||
@ -402,7 +378,7 @@ mod tests {
|
||||
let pw = PartialWitness::new();
|
||||
let mut builder = CircuitBuilder::<F, D>::new(config);
|
||||
|
||||
let lst: Vec<F> = (0..size * 2).map(|n| F::from_canonical_usize(n)).collect();
|
||||
let lst: Vec<F> = (0..size * 2).map(F::from_canonical_usize).collect();
|
||||
let a: Vec<Vec<Target>> = lst[..]
|
||||
.chunks(2)
|
||||
.map(|pair| vec![builder.constant(pair[0]), builder.constant(pair[1])])
|
||||
|
||||
@ -63,4 +63,21 @@ impl<const D: usize> PolynomialCoeffsExtAlgebraTarget<D> {
|
||||
}
|
||||
acc
|
||||
}
|
||||
|
||||
/// Evaluate the polynomial at a point given its powers. The first power is the point itself, not 1.
|
||||
pub fn eval_with_powers<F>(
|
||||
&self,
|
||||
builder: &mut CircuitBuilder<F, D>,
|
||||
powers: &[ExtensionAlgebraTarget<D>],
|
||||
) -> ExtensionAlgebraTarget<D>
|
||||
where
|
||||
F: RichField + Extendable<D>,
|
||||
{
|
||||
debug_assert_eq!(self.0.len(), powers.len() + 1);
|
||||
let acc = self.0[0];
|
||||
self.0[1..]
|
||||
.iter()
|
||||
.zip(powers)
|
||||
.fold(acc, |acc, (&x, &c)| builder.mul_add_ext_algebra(c, x, acc))
|
||||
}
|
||||
}
|
||||
|
||||
@ -3,49 +3,20 @@ use crate::field::extension_field::Extendable;
|
||||
use crate::gates::random_access::RandomAccessGate;
|
||||
use crate::iop::target::Target;
|
||||
use crate::plonk::circuit_builder::CircuitBuilder;
|
||||
use crate::util::log2_strict;
|
||||
|
||||
impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
|
||||
/// Finds the last available random access gate with the given `vec_size` or add one if there aren't any.
|
||||
/// Returns `(g,i)` such that there is a random access gate with the given `vec_size` at index
|
||||
/// `g` and the gate's `i`-th random access is available.
|
||||
fn find_random_access_gate(&mut self, vec_size: usize) -> (usize, usize) {
|
||||
let (gate, i) = self
|
||||
.free_random_access
|
||||
.get(&vec_size)
|
||||
.copied()
|
||||
.unwrap_or_else(|| {
|
||||
let gate = self.add_gate(
|
||||
RandomAccessGate::new_from_config(&self.config, vec_size),
|
||||
vec![],
|
||||
);
|
||||
(gate, 0)
|
||||
});
|
||||
|
||||
// Update `free_random_access` with new values.
|
||||
if i < RandomAccessGate::<F, D>::max_num_copies(
|
||||
self.config.num_routed_wires,
|
||||
self.config.num_wires,
|
||||
vec_size,
|
||||
) - 1
|
||||
{
|
||||
self.free_random_access.insert(vec_size, (gate, i + 1));
|
||||
} else {
|
||||
self.free_random_access.remove(&vec_size);
|
||||
}
|
||||
|
||||
(gate, i)
|
||||
}
|
||||
|
||||
/// Checks that a `Target` matches a vector at a non-deterministic index.
|
||||
/// Note: `access_index` is not range-checked.
|
||||
pub fn random_access(&mut self, access_index: Target, claimed_element: Target, v: Vec<Target>) {
|
||||
let vec_size = v.len();
|
||||
let bits = log2_strict(vec_size);
|
||||
debug_assert!(vec_size > 0);
|
||||
if vec_size == 1 {
|
||||
return self.connect(claimed_element, v[0]);
|
||||
}
|
||||
let (gate_index, copy) = self.find_random_access_gate(vec_size);
|
||||
let dummy_gate = RandomAccessGate::<F, D>::new_from_config(&self.config, vec_size);
|
||||
let (gate_index, copy) = self.find_random_access_gate(bits);
|
||||
let dummy_gate = RandomAccessGate::<F, D>::new_from_config(&self.config, bits);
|
||||
|
||||
v.iter().enumerate().for_each(|(i, &val)| {
|
||||
self.connect(
|
||||
|
||||
@ -3,6 +3,8 @@ use std::marker::PhantomData;
|
||||
use itertools::izip;
|
||||
|
||||
use crate::field::extension_field::Extendable;
|
||||
use crate::field::field_types::{Field, RichField};
|
||||
use crate::gates::assert_le::AssertLessThanGate;
|
||||
use crate::field::field_types::Field;
|
||||
use crate::gates::comparison::ComparisonGate;
|
||||
use crate::iop::generator::{GeneratedValues, SimpleGenerator};
|
||||
@ -40,9 +42,9 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
|
||||
self.assert_permutation(a_chunks, b_chunks);
|
||||
}
|
||||
|
||||
/// Add a ComparisonGate to assert that `lhs` is less than `rhs`, where their values are at most `bits` bits.
|
||||
/// Add an AssertLessThanGate to assert that `lhs` is less than `rhs`, where their values are at most `bits` bits.
|
||||
pub fn assert_le(&mut self, lhs: Target, rhs: Target, bits: usize, num_chunks: usize) {
|
||||
let gate = ComparisonGate::new(bits, num_chunks);
|
||||
let gate = AssertLessThanGate::new(bits, num_chunks);
|
||||
let gate_index = self.add_gate(gate.clone(), vec![]);
|
||||
|
||||
self.connect(Target::wire(gate_index, gate.wire_first_input()), lhs);
|
||||
@ -126,8 +128,7 @@ impl<F: Extendable<D>, const D: usize> SimpleGenerator<F> for MemoryOpSortGenera
|
||||
fn dependencies(&self) -> Vec<Target> {
|
||||
self.input_ops
|
||||
.iter()
|
||||
.map(|op| vec![op.is_write.target, op.address, op.timestamp, op.value])
|
||||
.flatten()
|
||||
.flat_map(|op| vec![op.is_write.target, op.address, op.timestamp, op.value])
|
||||
.collect()
|
||||
}
|
||||
|
||||
@ -223,7 +224,7 @@ mod tests {
|
||||
izip!(is_write_vals, address_vals, timestamp_vals, value_vals)
|
||||
.zip(combined_vals_u64)
|
||||
.collect::<Vec<_>>();
|
||||
input_ops_and_keys.sort_by_key(|(_, val)| val.clone());
|
||||
input_ops_and_keys.sort_by_key(|(_, val)| *val);
|
||||
let input_ops_sorted: Vec<_> = input_ops_and_keys.iter().map(|(x, _)| x).collect();
|
||||
|
||||
let output_ops =
|
||||
|
||||
@ -1,5 +1,7 @@
|
||||
use std::borrow::Borrow;
|
||||
|
||||
use itertools::Itertools;
|
||||
|
||||
use crate::field::extension_field::Extendable;
|
||||
use crate::field::field_types::Field;
|
||||
use crate::gates::base_sum::BaseSumGate;
|
||||
@ -27,23 +29,25 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
|
||||
|
||||
/// Takes an iterator of bits `(b_i)` and returns `sum b_i * 2^i`, i.e.,
|
||||
/// the number with little-endian bit representation given by `bits`.
|
||||
pub(crate) fn le_sum(
|
||||
&mut self,
|
||||
bits: impl ExactSizeIterator<Item = impl Borrow<BoolTarget>> + Clone,
|
||||
) -> Target {
|
||||
pub(crate) fn le_sum(&mut self, bits: impl Iterator<Item = impl Borrow<BoolTarget>>) -> Target {
|
||||
let bits = bits.map(|b| *b.borrow()).collect_vec();
|
||||
let num_bits = bits.len();
|
||||
if num_bits == 0 {
|
||||
return self.zero();
|
||||
} else if num_bits == 1 {
|
||||
let mut bits = bits;
|
||||
return bits.next().unwrap().borrow().target;
|
||||
} else if num_bits == 2 {
|
||||
let two = self.two();
|
||||
let mut bits = bits;
|
||||
let b0 = bits.next().unwrap().borrow().target;
|
||||
let b1 = bits.next().unwrap().borrow().target;
|
||||
return self.mul_add(two, b1, b0);
|
||||
}
|
||||
|
||||
// Check if it's cheaper to just do this with arithmetic operations.
|
||||
let arithmetic_ops = num_bits - 1;
|
||||
if arithmetic_ops <= self.num_base_arithmetic_ops_per_gate() {
|
||||
let two = self.two();
|
||||
let mut rev_bits = bits.iter().rev();
|
||||
let mut sum = rev_bits.next().unwrap().target;
|
||||
for &bit in rev_bits {
|
||||
sum = self.mul_add(two, sum, bit.target);
|
||||
}
|
||||
return sum;
|
||||
}
|
||||
|
||||
debug_assert!(
|
||||
BaseSumGate::<2>::START_LIMBS + num_bits <= self.config.num_routed_wires,
|
||||
"Not enough routed wires."
|
||||
@ -51,10 +55,10 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
|
||||
let gate_type = BaseSumGate::<2>::new_from_config::<F>(&self.config);
|
||||
let gate_index = self.add_gate(gate_type, vec![]);
|
||||
for (limb, wire) in bits
|
||||
.clone()
|
||||
.iter()
|
||||
.zip(BaseSumGate::<2>::START_LIMBS..BaseSumGate::<2>::START_LIMBS + num_bits)
|
||||
{
|
||||
self.connect(limb.borrow().target, Target::wire(gate_index, wire));
|
||||
self.connect(limb.target, Target::wire(gate_index, wire));
|
||||
}
|
||||
for l in gate_type.limbs().skip(num_bits) {
|
||||
self.assert_zero(Target::wire(gate_index, l));
|
||||
@ -62,7 +66,7 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
|
||||
|
||||
self.add_simple_generator(BaseSumGenerator::<2> {
|
||||
gate_index,
|
||||
limbs: bits.map(|l| *l.borrow()).collect(),
|
||||
limbs: bits,
|
||||
});
|
||||
|
||||
Target::wire(gate_index, BaseSumGate::<2>::WIRE_SUM)
|
||||
@ -146,14 +150,14 @@ mod tests {
|
||||
let pw = PartialWitness::new();
|
||||
let mut builder = CircuitBuilder::<F, D>::new(config);
|
||||
|
||||
let n = thread_rng().gen_range(0..(1 << 10));
|
||||
let n = thread_rng().gen_range(0..(1 << 30));
|
||||
let x = builder.constant(F::from_canonical_usize(n));
|
||||
|
||||
let zero = builder._false();
|
||||
let one = builder._true();
|
||||
|
||||
let y = builder.le_sum(
|
||||
(0..10)
|
||||
(0..30)
|
||||
.scan(n, |acc, _| {
|
||||
let tmp = *acc % 2;
|
||||
*acc /= 2;
|
||||
|
||||
@ -24,8 +24,7 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
|
||||
|
||||
let mut bits = Vec::with_capacity(num_bits);
|
||||
for &gate in &gates {
|
||||
let start_limbs = BaseSumGate::<2>::START_LIMBS;
|
||||
for limb_input in start_limbs..start_limbs + gate_type.num_limbs {
|
||||
for limb_input in gate_type.limbs() {
|
||||
// `new_unsafe` is safe here because BaseSumGate::<2> forces it to be in `{0, 1}`.
|
||||
bits.push(BoolTarget::new_unsafe(Target::wire(gate, limb_input)));
|
||||
}
|
||||
@ -35,10 +34,11 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
|
||||
}
|
||||
|
||||
let zero = self.zero();
|
||||
let base = F::TWO.exp_u64(gate_type.num_limbs as u64);
|
||||
let mut acc = zero;
|
||||
for &gate in gates.iter().rev() {
|
||||
let sum = Target::wire(gate, BaseSumGate::<2>::WIRE_SUM);
|
||||
acc = self.mul_const_add(F::from_canonical_usize(1 << gate_type.num_limbs), acc, sum);
|
||||
acc = self.mul_const_add(base, acc, sum);
|
||||
}
|
||||
self.connect(acc, integer);
|
||||
|
||||
@ -96,11 +96,18 @@ impl<F: RichField> SimpleGenerator<F> for WireSplitGenerator {
|
||||
|
||||
for &gate in &self.gates {
|
||||
let sum = Target::wire(gate, BaseSumGate::<2>::WIRE_SUM);
|
||||
out_buffer.set_target(
|
||||
sum,
|
||||
F::from_canonical_u64(integer_value & ((1 << self.num_limbs) - 1)),
|
||||
);
|
||||
integer_value >>= self.num_limbs;
|
||||
|
||||
// If num_limbs >= 64, we don't need to truncate since `integer_value` is already
|
||||
// limited to 64 bits, and trying to do so would cause overflow. Hence the conditional.
|
||||
let mut truncated_value = integer_value;
|
||||
if self.num_limbs < 64 {
|
||||
truncated_value = integer_value & ((1 << self.num_limbs) - 1);
|
||||
integer_value >>= self.num_limbs;
|
||||
} else {
|
||||
integer_value = 0;
|
||||
};
|
||||
|
||||
out_buffer.set_target(sum, F::from_canonical_u64(truncated_value));
|
||||
}
|
||||
|
||||
debug_assert_eq!(
|
||||
|
||||
212
src/gates/arithmetic_base.rs
Normal file
212
src/gates/arithmetic_base.rs
Normal file
@ -0,0 +1,212 @@
|
||||
use crate::field::extension_field::target::ExtensionTarget;
|
||||
use crate::field::extension_field::Extendable;
|
||||
use crate::field::field_types::RichField;
|
||||
use crate::gates::gate::Gate;
|
||||
use crate::iop::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator};
|
||||
use crate::iop::target::Target;
|
||||
use crate::iop::witness::{PartitionWitness, Witness};
|
||||
use crate::plonk::circuit_builder::CircuitBuilder;
|
||||
use crate::plonk::circuit_data::CircuitConfig;
|
||||
use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase};
|
||||
|
||||
/// A gate which can perform a weighted multiply-add, i.e. `result = c0 x y + c1 z`. If the config
|
||||
/// supports enough routed wires, it can support several such operations in one gate.
|
||||
#[derive(Debug)]
|
||||
pub struct ArithmeticGate {
|
||||
/// Number of arithmetic operations performed by an arithmetic gate.
|
||||
pub num_ops: usize,
|
||||
}
|
||||
|
||||
impl ArithmeticGate {
|
||||
pub fn new_from_config(config: &CircuitConfig) -> Self {
|
||||
Self {
|
||||
num_ops: Self::num_ops(config),
|
||||
}
|
||||
}
|
||||
|
||||
/// Determine the maximum number of operations that can fit in one gate for the given config.
|
||||
pub(crate) fn num_ops(config: &CircuitConfig) -> usize {
|
||||
let wires_per_op = 4;
|
||||
config.num_routed_wires / wires_per_op
|
||||
}
|
||||
|
||||
pub fn wire_ith_multiplicand_0(i: usize) -> usize {
|
||||
4 * i
|
||||
}
|
||||
pub fn wire_ith_multiplicand_1(i: usize) -> usize {
|
||||
4 * i + 1
|
||||
}
|
||||
pub fn wire_ith_addend(i: usize) -> usize {
|
||||
4 * i + 2
|
||||
}
|
||||
pub fn wire_ith_output(i: usize) -> usize {
|
||||
4 * i + 3
|
||||
}
|
||||
}
|
||||
|
||||
impl<F: RichField + Extendable<D>, const D: usize> Gate<F, D> for ArithmeticGate {
|
||||
fn id(&self) -> String {
|
||||
format!("{:?}", self)
|
||||
}
|
||||
|
||||
fn eval_unfiltered(&self, vars: EvaluationVars<F, D>) -> Vec<F::Extension> {
|
||||
let const_0 = vars.local_constants[0];
|
||||
let const_1 = vars.local_constants[1];
|
||||
|
||||
let mut constraints = Vec::new();
|
||||
for i in 0..self.num_ops {
|
||||
let multiplicand_0 = vars.local_wires[Self::wire_ith_multiplicand_0(i)];
|
||||
let multiplicand_1 = vars.local_wires[Self::wire_ith_multiplicand_1(i)];
|
||||
let addend = vars.local_wires[Self::wire_ith_addend(i)];
|
||||
let output = vars.local_wires[Self::wire_ith_output(i)];
|
||||
let computed_output = multiplicand_0 * multiplicand_1 * const_0 + addend * const_1;
|
||||
|
||||
constraints.push(output - computed_output);
|
||||
}
|
||||
|
||||
constraints
|
||||
}
|
||||
|
||||
fn eval_unfiltered_base(&self, vars: EvaluationVarsBase<F>) -> Vec<F> {
|
||||
let const_0 = vars.local_constants[0];
|
||||
let const_1 = vars.local_constants[1];
|
||||
|
||||
let mut constraints = Vec::new();
|
||||
for i in 0..self.num_ops {
|
||||
let multiplicand_0 = vars.local_wires[Self::wire_ith_multiplicand_0(i)];
|
||||
let multiplicand_1 = vars.local_wires[Self::wire_ith_multiplicand_1(i)];
|
||||
let addend = vars.local_wires[Self::wire_ith_addend(i)];
|
||||
let output = vars.local_wires[Self::wire_ith_output(i)];
|
||||
let computed_output = multiplicand_0 * multiplicand_1 * const_0 + addend * const_1;
|
||||
|
||||
constraints.push(output - computed_output);
|
||||
}
|
||||
|
||||
constraints
|
||||
}
|
||||
|
||||
fn eval_unfiltered_recursively(
|
||||
&self,
|
||||
builder: &mut CircuitBuilder<F, D>,
|
||||
vars: EvaluationTargets<D>,
|
||||
) -> Vec<ExtensionTarget<D>> {
|
||||
let const_0 = vars.local_constants[0];
|
||||
let const_1 = vars.local_constants[1];
|
||||
|
||||
let mut constraints = Vec::new();
|
||||
for i in 0..self.num_ops {
|
||||
let multiplicand_0 = vars.local_wires[Self::wire_ith_multiplicand_0(i)];
|
||||
let multiplicand_1 = vars.local_wires[Self::wire_ith_multiplicand_1(i)];
|
||||
let addend = vars.local_wires[Self::wire_ith_addend(i)];
|
||||
let output = vars.local_wires[Self::wire_ith_output(i)];
|
||||
let computed_output = {
|
||||
let scaled_mul =
|
||||
builder.mul_many_extension(&[const_0, multiplicand_0, multiplicand_1]);
|
||||
builder.mul_add_extension(const_1, addend, scaled_mul)
|
||||
};
|
||||
|
||||
let diff = builder.sub_extension(output, computed_output);
|
||||
constraints.push(diff);
|
||||
}
|
||||
|
||||
constraints
|
||||
}
|
||||
|
||||
fn generators(
|
||||
&self,
|
||||
gate_index: usize,
|
||||
local_constants: &[F],
|
||||
) -> Vec<Box<dyn WitnessGenerator<F>>> {
|
||||
(0..self.num_ops)
|
||||
.map(|i| {
|
||||
let g: Box<dyn WitnessGenerator<F>> = Box::new(
|
||||
ArithmeticBaseGenerator {
|
||||
gate_index,
|
||||
const_0: local_constants[0],
|
||||
const_1: local_constants[1],
|
||||
i,
|
||||
}
|
||||
.adapter(),
|
||||
);
|
||||
g
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
}
|
||||
|
||||
fn num_wires(&self) -> usize {
|
||||
self.num_ops * 4
|
||||
}
|
||||
|
||||
fn num_constants(&self) -> usize {
|
||||
2
|
||||
}
|
||||
|
||||
fn degree(&self) -> usize {
|
||||
3
|
||||
}
|
||||
|
||||
fn num_constraints(&self) -> usize {
|
||||
self.num_ops
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
struct ArithmeticBaseGenerator<F: RichField + Extendable<D>, const D: usize> {
|
||||
gate_index: usize,
|
||||
const_0: F,
|
||||
const_1: F,
|
||||
i: usize,
|
||||
}
|
||||
|
||||
impl<F: RichField + Extendable<D>, const D: usize> SimpleGenerator<F>
|
||||
for ArithmeticBaseGenerator<F, D>
|
||||
{
|
||||
fn dependencies(&self) -> Vec<Target> {
|
||||
[
|
||||
ArithmeticGate::wire_ith_multiplicand_0(self.i),
|
||||
ArithmeticGate::wire_ith_multiplicand_1(self.i),
|
||||
ArithmeticGate::wire_ith_addend(self.i),
|
||||
]
|
||||
.iter()
|
||||
.map(|&i| Target::wire(self.gate_index, i))
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn run_once(&self, witness: &PartitionWitness<F>, out_buffer: &mut GeneratedValues<F>) {
|
||||
let get_wire =
|
||||
|wire: usize| -> F { witness.get_target(Target::wire(self.gate_index, wire)) };
|
||||
|
||||
let multiplicand_0 = get_wire(ArithmeticGate::wire_ith_multiplicand_0(self.i));
|
||||
let multiplicand_1 = get_wire(ArithmeticGate::wire_ith_multiplicand_1(self.i));
|
||||
let addend = get_wire(ArithmeticGate::wire_ith_addend(self.i));
|
||||
|
||||
let output_target = Target::wire(self.gate_index, ArithmeticGate::wire_ith_output(self.i));
|
||||
|
||||
let computed_output =
|
||||
multiplicand_0 * multiplicand_1 * self.const_0 + addend * self.const_1;
|
||||
|
||||
out_buffer.set_target(output_target, computed_output)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use anyhow::Result;
|
||||
|
||||
use crate::field::goldilocks_field::GoldilocksField;
|
||||
use crate::gates::arithmetic_base::ArithmeticGate;
|
||||
use crate::gates::gate_testing::{test_eval_fns, test_low_degree};
|
||||
use crate::plonk::circuit_data::CircuitConfig;
|
||||
|
||||
#[test]
|
||||
fn low_degree() {
|
||||
let gate = ArithmeticGate::new_from_config(&CircuitConfig::standard_recursion_config());
|
||||
test_low_degree::<GoldilocksField, _, 4>(gate);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn eval_fns() -> Result<()> {
|
||||
let gate = ArithmeticGate::new_from_config(&CircuitConfig::standard_recursion_config());
|
||||
test_eval_fns::<GoldilocksField, _, 4>(gate)
|
||||
}
|
||||
}
|
||||
@ -11,7 +11,8 @@ use crate::plonk::circuit_builder::CircuitBuilder;
|
||||
use crate::plonk::circuit_data::CircuitConfig;
|
||||
use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase};
|
||||
|
||||
/// A gate which can a linear combination `c0*x*y+c1*z` twice with the same `x`.
|
||||
/// A gate which can perform a weighted multiply-add, i.e. `result = c0 x y + c1 z`. If the config
|
||||
/// supports enough routed wires, it can support several such operations in one gate.
|
||||
#[derive(Debug)]
|
||||
pub struct ArithmeticExtensionGate<const D: usize> {
|
||||
/// Number of arithmetic operations performed by an arithmetic gate.
|
||||
@ -203,7 +204,7 @@ mod tests {
|
||||
use anyhow::Result;
|
||||
|
||||
use crate::field::goldilocks_field::GoldilocksField;
|
||||
use crate::gates::arithmetic::ArithmeticExtensionGate;
|
||||
use crate::gates::arithmetic_extension::ArithmeticExtensionGate;
|
||||
use crate::gates::gate_testing::{test_eval_fns, test_low_degree};
|
||||
use crate::plonk::circuit_data::CircuitConfig;
|
||||
use crate::plonk::config::{GenericConfig, PoseidonGoldilocksConfig};
|
||||
@ -11,37 +11,49 @@ use crate::iop::target::Target;
|
||||
use crate::iop::wire::Wire;
|
||||
use crate::iop::witness::{PartitionWitness, Witness};
|
||||
use crate::plonk::circuit_builder::CircuitBuilder;
|
||||
use crate::plonk::circuit_data::CircuitConfig;
|
||||
use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase};
|
||||
|
||||
/// Number of arithmetic operations performed by an arithmetic gate.
|
||||
pub const NUM_U32_ARITHMETIC_OPS: usize = 3;
|
||||
|
||||
/// A gate to perform a basic mul-add on 32-bit values (we assume they are range-checked beforehand).
|
||||
#[derive(Debug)]
|
||||
#[derive(Copy, Clone, Debug)]
|
||||
pub struct U32ArithmeticGate<F: Extendable<D>, const D: usize> {
|
||||
pub num_ops: usize,
|
||||
_phantom: PhantomData<F>,
|
||||
}
|
||||
|
||||
impl<F: Extendable<D>, const D: usize> U32ArithmeticGate<F, D> {
|
||||
pub fn wire_ith_multiplicand_0(i: usize) -> usize {
|
||||
debug_assert!(i < NUM_U32_ARITHMETIC_OPS);
|
||||
pub fn new_from_config(config: &CircuitConfig) -> Self {
|
||||
Self {
|
||||
num_ops: Self::num_ops(config),
|
||||
_phantom: PhantomData,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn num_ops(config: &CircuitConfig) -> usize {
|
||||
let wires_per_op = 5 + Self::num_limbs();
|
||||
let routed_wires_per_op = 5;
|
||||
(config.num_wires / wires_per_op).min(config.num_routed_wires / routed_wires_per_op)
|
||||
}
|
||||
|
||||
pub fn wire_ith_multiplicand_0(&self, i: usize) -> usize {
|
||||
debug_assert!(i < self.num_ops);
|
||||
5 * i
|
||||
}
|
||||
pub fn wire_ith_multiplicand_1(i: usize) -> usize {
|
||||
debug_assert!(i < NUM_U32_ARITHMETIC_OPS);
|
||||
pub fn wire_ith_multiplicand_1(&self, i: usize) -> usize {
|
||||
debug_assert!(i < self.num_ops);
|
||||
5 * i + 1
|
||||
}
|
||||
pub fn wire_ith_addend(i: usize) -> usize {
|
||||
debug_assert!(i < NUM_U32_ARITHMETIC_OPS);
|
||||
pub fn wire_ith_addend(&self, i: usize) -> usize {
|
||||
debug_assert!(i < self.num_ops);
|
||||
5 * i + 2
|
||||
}
|
||||
|
||||
pub fn wire_ith_output_low_half(i: usize) -> usize {
|
||||
debug_assert!(i < NUM_U32_ARITHMETIC_OPS);
|
||||
pub fn wire_ith_output_low_half(&self, i: usize) -> usize {
|
||||
debug_assert!(i < self.num_ops);
|
||||
5 * i + 3
|
||||
}
|
||||
pub fn wire_ith_output_high_half(i: usize) -> usize {
|
||||
debug_assert!(i < NUM_U32_ARITHMETIC_OPS);
|
||||
pub fn wire_ith_output_high_half(&self, i: usize) -> usize {
|
||||
debug_assert!(i < self.num_ops);
|
||||
5 * i + 4
|
||||
}
|
||||
|
||||
@ -52,10 +64,10 @@ impl<F: Extendable<D>, const D: usize> U32ArithmeticGate<F, D> {
|
||||
64 / Self::limb_bits()
|
||||
}
|
||||
|
||||
pub fn wire_ith_output_jth_limb(i: usize, j: usize) -> usize {
|
||||
debug_assert!(i < NUM_U32_ARITHMETIC_OPS);
|
||||
pub fn wire_ith_output_jth_limb(&self, i: usize, j: usize) -> usize {
|
||||
debug_assert!(i < self.num_ops);
|
||||
debug_assert!(j < Self::num_limbs());
|
||||
5 * NUM_U32_ARITHMETIC_OPS + Self::num_limbs() * i + j
|
||||
5 * self.num_ops + Self::num_limbs() * i + j
|
||||
}
|
||||
}
|
||||
|
||||
@ -66,15 +78,15 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for U32ArithmeticGate<F, D> {
|
||||
|
||||
fn eval_unfiltered(&self, vars: EvaluationVars<F, D>) -> Vec<F::Extension> {
|
||||
let mut constraints = Vec::with_capacity(self.num_constraints());
|
||||
for i in 0..NUM_U32_ARITHMETIC_OPS {
|
||||
let multiplicand_0 = vars.local_wires[Self::wire_ith_multiplicand_0(i)];
|
||||
let multiplicand_1 = vars.local_wires[Self::wire_ith_multiplicand_1(i)];
|
||||
let addend = vars.local_wires[Self::wire_ith_addend(i)];
|
||||
for i in 0..self.num_ops {
|
||||
let multiplicand_0 = vars.local_wires[self.wire_ith_multiplicand_0(i)];
|
||||
let multiplicand_1 = vars.local_wires[self.wire_ith_multiplicand_1(i)];
|
||||
let addend = vars.local_wires[self.wire_ith_addend(i)];
|
||||
|
||||
let computed_output = multiplicand_0 * multiplicand_1 + addend;
|
||||
|
||||
let output_low = vars.local_wires[Self::wire_ith_output_low_half(i)];
|
||||
let output_high = vars.local_wires[Self::wire_ith_output_high_half(i)];
|
||||
let output_low = vars.local_wires[self.wire_ith_output_low_half(i)];
|
||||
let output_high = vars.local_wires[self.wire_ith_output_high_half(i)];
|
||||
|
||||
let base = F::Extension::from_canonical_u64(1 << 32u64);
|
||||
let combined_output = output_high * base + output_low;
|
||||
@ -86,7 +98,7 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for U32ArithmeticGate<F, D> {
|
||||
let midpoint = Self::num_limbs() / 2;
|
||||
let base = F::Extension::from_canonical_u64(1u64 << Self::limb_bits());
|
||||
for j in (0..Self::num_limbs()).rev() {
|
||||
let this_limb = vars.local_wires[Self::wire_ith_output_jth_limb(i, j)];
|
||||
let this_limb = vars.local_wires[self.wire_ith_output_jth_limb(i, j)];
|
||||
let max_limb = 1 << Self::limb_bits();
|
||||
let product = (0..max_limb)
|
||||
.map(|x| this_limb - F::Extension::from_canonical_usize(x))
|
||||
@ -108,15 +120,15 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for U32ArithmeticGate<F, D> {
|
||||
|
||||
fn eval_unfiltered_base(&self, vars: EvaluationVarsBase<F>) -> Vec<F> {
|
||||
let mut constraints = Vec::with_capacity(self.num_constraints());
|
||||
for i in 0..NUM_U32_ARITHMETIC_OPS {
|
||||
let multiplicand_0 = vars.local_wires[Self::wire_ith_multiplicand_0(i)];
|
||||
let multiplicand_1 = vars.local_wires[Self::wire_ith_multiplicand_1(i)];
|
||||
let addend = vars.local_wires[Self::wire_ith_addend(i)];
|
||||
for i in 0..self.num_ops {
|
||||
let multiplicand_0 = vars.local_wires[self.wire_ith_multiplicand_0(i)];
|
||||
let multiplicand_1 = vars.local_wires[self.wire_ith_multiplicand_1(i)];
|
||||
let addend = vars.local_wires[self.wire_ith_addend(i)];
|
||||
|
||||
let computed_output = multiplicand_0 * multiplicand_1 + addend;
|
||||
|
||||
let output_low = vars.local_wires[Self::wire_ith_output_low_half(i)];
|
||||
let output_high = vars.local_wires[Self::wire_ith_output_high_half(i)];
|
||||
let output_low = vars.local_wires[self.wire_ith_output_low_half(i)];
|
||||
let output_high = vars.local_wires[self.wire_ith_output_high_half(i)];
|
||||
|
||||
let base = F::from_canonical_u64(1 << 32u64);
|
||||
let combined_output = output_high * base + output_low;
|
||||
@ -128,7 +140,7 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for U32ArithmeticGate<F, D> {
|
||||
let midpoint = Self::num_limbs() / 2;
|
||||
let base = F::from_canonical_u64(1u64 << Self::limb_bits());
|
||||
for j in (0..Self::num_limbs()).rev() {
|
||||
let this_limb = vars.local_wires[Self::wire_ith_output_jth_limb(i, j)];
|
||||
let this_limb = vars.local_wires[self.wire_ith_output_jth_limb(i, j)];
|
||||
let max_limb = 1 << Self::limb_bits();
|
||||
let product = (0..max_limb)
|
||||
.map(|x| this_limb - F::from_canonical_usize(x))
|
||||
@ -155,15 +167,15 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for U32ArithmeticGate<F, D> {
|
||||
) -> Vec<ExtensionTarget<D>> {
|
||||
let mut constraints = Vec::with_capacity(self.num_constraints());
|
||||
|
||||
for i in 0..NUM_U32_ARITHMETIC_OPS {
|
||||
let multiplicand_0 = vars.local_wires[Self::wire_ith_multiplicand_0(i)];
|
||||
let multiplicand_1 = vars.local_wires[Self::wire_ith_multiplicand_1(i)];
|
||||
let addend = vars.local_wires[Self::wire_ith_addend(i)];
|
||||
for i in 0..self.num_ops {
|
||||
let multiplicand_0 = vars.local_wires[self.wire_ith_multiplicand_0(i)];
|
||||
let multiplicand_1 = vars.local_wires[self.wire_ith_multiplicand_1(i)];
|
||||
let addend = vars.local_wires[self.wire_ith_addend(i)];
|
||||
|
||||
let computed_output = builder.mul_add_extension(multiplicand_0, multiplicand_1, addend);
|
||||
|
||||
let output_low = vars.local_wires[Self::wire_ith_output_low_half(i)];
|
||||
let output_high = vars.local_wires[Self::wire_ith_output_high_half(i)];
|
||||
let output_low = vars.local_wires[self.wire_ith_output_low_half(i)];
|
||||
let output_high = vars.local_wires[self.wire_ith_output_high_half(i)];
|
||||
|
||||
let base: F::Extension = F::from_canonical_u64(1 << 32u64).into();
|
||||
let base_target = builder.constant_extension(base);
|
||||
@ -177,7 +189,7 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for U32ArithmeticGate<F, D> {
|
||||
let base = builder
|
||||
.constant_extension(F::Extension::from_canonical_u64(1u64 << Self::limb_bits()));
|
||||
for j in (0..Self::num_limbs()).rev() {
|
||||
let this_limb = vars.local_wires[Self::wire_ith_output_jth_limb(i, j)];
|
||||
let this_limb = vars.local_wires[self.wire_ith_output_jth_limb(i, j)];
|
||||
let max_limb = 1 << Self::limb_bits();
|
||||
|
||||
let mut product = builder.one_extension();
|
||||
@ -210,10 +222,11 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for U32ArithmeticGate<F, D> {
|
||||
gate_index: usize,
|
||||
_local_constants: &[F],
|
||||
) -> Vec<Box<dyn WitnessGenerator<F>>> {
|
||||
(0..NUM_U32_ARITHMETIC_OPS)
|
||||
(0..self.num_ops)
|
||||
.map(|i| {
|
||||
let g: Box<dyn WitnessGenerator<F>> = Box::new(
|
||||
U32ArithmeticGenerator {
|
||||
gate: *self,
|
||||
gate_index,
|
||||
i,
|
||||
_phantom: PhantomData,
|
||||
@ -226,7 +239,7 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for U32ArithmeticGate<F, D> {
|
||||
}
|
||||
|
||||
fn num_wires(&self) -> usize {
|
||||
NUM_U32_ARITHMETIC_OPS * (5 + Self::num_limbs())
|
||||
self.num_ops * (5 + Self::num_limbs())
|
||||
}
|
||||
|
||||
fn num_constants(&self) -> usize {
|
||||
@ -238,12 +251,13 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for U32ArithmeticGate<F, D> {
|
||||
}
|
||||
|
||||
fn num_constraints(&self) -> usize {
|
||||
NUM_U32_ARITHMETIC_OPS * (3 + Self::num_limbs())
|
||||
self.num_ops * (3 + Self::num_limbs())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
struct U32ArithmeticGenerator<F: Extendable<D>, const D: usize> {
|
||||
gate: U32ArithmeticGate<F, D>,
|
||||
gate_index: usize,
|
||||
i: usize,
|
||||
_phantom: PhantomData<F>,
|
||||
@ -253,17 +267,11 @@ impl<F: Extendable<D>, const D: usize> SimpleGenerator<F> for U32ArithmeticGener
|
||||
fn dependencies(&self) -> Vec<Target> {
|
||||
let local_target = |input| Target::wire(self.gate_index, input);
|
||||
|
||||
let mut deps = Vec::with_capacity(3);
|
||||
deps.push(local_target(
|
||||
U32ArithmeticGate::<F, D>::wire_ith_multiplicand_0(self.i),
|
||||
));
|
||||
deps.push(local_target(
|
||||
U32ArithmeticGate::<F, D>::wire_ith_multiplicand_1(self.i),
|
||||
));
|
||||
deps.push(local_target(U32ArithmeticGate::<F, D>::wire_ith_addend(
|
||||
self.i,
|
||||
)));
|
||||
deps
|
||||
vec![
|
||||
local_target(self.gate.wire_ith_multiplicand_0(self.i)),
|
||||
local_target(self.gate.wire_ith_multiplicand_1(self.i)),
|
||||
local_target(self.gate.wire_ith_addend(self.i)),
|
||||
]
|
||||
}
|
||||
|
||||
fn run_once(&self, witness: &PartitionWitness<F>, out_buffer: &mut GeneratedValues<F>) {
|
||||
@ -274,11 +282,9 @@ impl<F: Extendable<D>, const D: usize> SimpleGenerator<F> for U32ArithmeticGener
|
||||
|
||||
let get_local_wire = |input| witness.get_wire(local_wire(input));
|
||||
|
||||
let multiplicand_0 =
|
||||
get_local_wire(U32ArithmeticGate::<F, D>::wire_ith_multiplicand_0(self.i));
|
||||
let multiplicand_1 =
|
||||
get_local_wire(U32ArithmeticGate::<F, D>::wire_ith_multiplicand_1(self.i));
|
||||
let addend = get_local_wire(U32ArithmeticGate::<F, D>::wire_ith_addend(self.i));
|
||||
let multiplicand_0 = get_local_wire(self.gate.wire_ith_multiplicand_0(self.i));
|
||||
let multiplicand_1 = get_local_wire(self.gate.wire_ith_multiplicand_1(self.i));
|
||||
let addend = get_local_wire(self.gate.wire_ith_addend(self.i));
|
||||
|
||||
let output = multiplicand_0 * multiplicand_1 + addend;
|
||||
let mut output_u64 = output.to_canonical_u64();
|
||||
@ -289,34 +295,25 @@ impl<F: Extendable<D>, const D: usize> SimpleGenerator<F> for U32ArithmeticGener
|
||||
let output_high = F::from_canonical_u64(output_high_u64);
|
||||
let output_low = F::from_canonical_u64(output_low_u64);
|
||||
|
||||
let output_high_wire =
|
||||
local_wire(U32ArithmeticGate::<F, D>::wire_ith_output_high_half(self.i));
|
||||
let output_low_wire =
|
||||
local_wire(U32ArithmeticGate::<F, D>::wire_ith_output_low_half(self.i));
|
||||
let output_high_wire = local_wire(self.gate.wire_ith_output_high_half(self.i));
|
||||
let output_low_wire = local_wire(self.gate.wire_ith_output_low_half(self.i));
|
||||
|
||||
out_buffer.set_wire(output_high_wire, output_high);
|
||||
out_buffer.set_wire(output_low_wire, output_low);
|
||||
|
||||
let num_limbs = U32ArithmeticGate::<F, D>::num_limbs();
|
||||
let limb_base = 1 << U32ArithmeticGate::<F, D>::limb_bits();
|
||||
let output_limbs_u64: Vec<_> = unfold((), move |_| {
|
||||
let output_limbs_u64 = unfold((), move |_| {
|
||||
let ret = output_u64 % limb_base;
|
||||
output_u64 /= limb_base;
|
||||
Some(ret)
|
||||
})
|
||||
.take(num_limbs)
|
||||
.collect();
|
||||
let output_limbs_f: Vec<_> = output_limbs_u64
|
||||
.iter()
|
||||
.cloned()
|
||||
.map(F::from_canonical_u64)
|
||||
.collect();
|
||||
.take(num_limbs);
|
||||
let output_limbs_f = output_limbs_u64.map(F::from_canonical_u64);
|
||||
|
||||
for j in 0..num_limbs {
|
||||
let wire = local_wire(U32ArithmeticGate::<F, D>::wire_ith_output_jth_limb(
|
||||
self.i, j,
|
||||
));
|
||||
out_buffer.set_wire(wire, output_limbs_f[j]);
|
||||
for (j, output_limb) in output_limbs_f.enumerate() {
|
||||
let wire = local_wire(self.gate.wire_ith_output_jth_limb(self.i, j));
|
||||
out_buffer.set_wire(wire, output_limb);
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -330,7 +327,7 @@ mod tests {
|
||||
|
||||
use crate::field::field_types::Field;
|
||||
use crate::field::goldilocks_field::GoldilocksField;
|
||||
use crate::gates::arithmetic_u32::{U32ArithmeticGate, NUM_U32_ARITHMETIC_OPS};
|
||||
use crate::gates::arithmetic_u32::U32ArithmeticGate;
|
||||
use crate::gates::gate::Gate;
|
||||
use crate::gates::gate_testing::{test_eval_fns, test_low_degree};
|
||||
use crate::hash::hash_types::HashOut;
|
||||
@ -340,16 +337,15 @@ mod tests {
|
||||
#[test]
|
||||
fn low_degree() {
|
||||
test_low_degree::<GoldilocksField, _, 4>(U32ArithmeticGate::<GoldilocksField, 4> {
|
||||
num_ops: 3,
|
||||
_phantom: PhantomData,
|
||||
})
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn eval_fns() -> Result<()> {
|
||||
const D: usize = 2;
|
||||
type C = PoseidonGoldilocksConfig;
|
||||
type F = <C as GenericConfig<D>>::F;
|
||||
test_eval_fns::<F, C, _, D>(U32ArithmeticGate::<F, D> {
|
||||
test_eval_fns::<GoldilocksField, _, 4>(U32ArithmeticGate::<GoldilocksField, 4> {
|
||||
num_ops: 3,
|
||||
_phantom: PhantomData,
|
||||
})
|
||||
}
|
||||
@ -360,6 +356,7 @@ mod tests {
|
||||
type C = PoseidonGoldilocksConfig;
|
||||
type F = <C as GenericConfig<D>>::F;
|
||||
type FF = <C as GenericConfig<D>>::FE;
|
||||
const NUM_U32_ARITHMETIC_OPS: usize = 3;
|
||||
|
||||
fn get_wires(
|
||||
multiplicands_0: Vec<u64>,
|
||||
@ -387,8 +384,7 @@ mod tests {
|
||||
output /= limb_base;
|
||||
}
|
||||
let mut output_limbs_f: Vec<_> = output_limbs
|
||||
.iter()
|
||||
.cloned()
|
||||
.into_iter()
|
||||
.map(F::from_canonical_u64)
|
||||
.collect();
|
||||
|
||||
@ -418,6 +414,7 @@ mod tests {
|
||||
.collect();
|
||||
|
||||
let gate = U32ArithmeticGate::<F, D> {
|
||||
num_ops: NUM_U32_ARITHMETIC_OPS,
|
||||
_phantom: PhantomData,
|
||||
};
|
||||
|
||||
|
||||
607
src/gates/assert_le.rs
Normal file
607
src/gates/assert_le.rs
Normal file
@ -0,0 +1,607 @@
|
||||
use std::marker::PhantomData;
|
||||
|
||||
use crate::field::extension_field::target::ExtensionTarget;
|
||||
use crate::field::extension_field::Extendable;
|
||||
use crate::field::field_types::{Field, PrimeField, RichField};
|
||||
use crate::gates::gate::Gate;
|
||||
use crate::iop::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator};
|
||||
use crate::iop::target::Target;
|
||||
use crate::iop::wire::Wire;
|
||||
use crate::iop::witness::{PartitionWitness, Witness};
|
||||
use crate::plonk::circuit_builder::CircuitBuilder;
|
||||
use crate::plonk::plonk_common::{reduce_with_powers, reduce_with_powers_ext_recursive};
|
||||
use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase};
|
||||
use crate::util::{bits_u64, ceil_div_usize};
|
||||
|
||||
// TODO: replace/merge this gate with `ComparisonGate`.
|
||||
|
||||
/// A gate for checking that one value is less than or equal to another.
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct AssertLessThanGate<F: PrimeField + Extendable<D>, const D: usize> {
|
||||
pub(crate) num_bits: usize,
|
||||
pub(crate) num_chunks: usize,
|
||||
_phantom: PhantomData<F>,
|
||||
}
|
||||
|
||||
impl<F: RichField + Extendable<D>, const D: usize> AssertLessThanGate<F, D> {
|
||||
pub fn new(num_bits: usize, num_chunks: usize) -> Self {
|
||||
debug_assert!(num_bits < bits_u64(F::ORDER));
|
||||
Self {
|
||||
num_bits,
|
||||
num_chunks,
|
||||
_phantom: PhantomData,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn chunk_bits(&self) -> usize {
|
||||
ceil_div_usize(self.num_bits, self.num_chunks)
|
||||
}
|
||||
|
||||
pub fn wire_first_input(&self) -> usize {
|
||||
0
|
||||
}
|
||||
|
||||
pub fn wire_second_input(&self) -> usize {
|
||||
1
|
||||
}
|
||||
|
||||
pub fn wire_most_significant_diff(&self) -> usize {
|
||||
2
|
||||
}
|
||||
|
||||
pub fn wire_first_chunk_val(&self, chunk: usize) -> usize {
|
||||
debug_assert!(chunk < self.num_chunks);
|
||||
3 + chunk
|
||||
}
|
||||
|
||||
pub fn wire_second_chunk_val(&self, chunk: usize) -> usize {
|
||||
debug_assert!(chunk < self.num_chunks);
|
||||
3 + self.num_chunks + chunk
|
||||
}
|
||||
|
||||
pub fn wire_equality_dummy(&self, chunk: usize) -> usize {
|
||||
debug_assert!(chunk < self.num_chunks);
|
||||
3 + 2 * self.num_chunks + chunk
|
||||
}
|
||||
|
||||
pub fn wire_chunks_equal(&self, chunk: usize) -> usize {
|
||||
debug_assert!(chunk < self.num_chunks);
|
||||
3 + 3 * self.num_chunks + chunk
|
||||
}
|
||||
|
||||
pub fn wire_intermediate_value(&self, chunk: usize) -> usize {
|
||||
debug_assert!(chunk < self.num_chunks);
|
||||
3 + 4 * self.num_chunks + chunk
|
||||
}
|
||||
}
|
||||
|
||||
impl<F: RichField + Extendable<D>, const D: usize> Gate<F, D> for AssertLessThanGate<F, D> {
|
||||
fn id(&self) -> String {
|
||||
format!("{:?}<D={}>", self, D)
|
||||
}
|
||||
|
||||
fn eval_unfiltered(&self, vars: EvaluationVars<F, D>) -> Vec<F::Extension> {
|
||||
let mut constraints = Vec::with_capacity(self.num_constraints());
|
||||
|
||||
let first_input = vars.local_wires[self.wire_first_input()];
|
||||
let second_input = vars.local_wires[self.wire_second_input()];
|
||||
|
||||
// Get chunks and assert that they match
|
||||
let first_chunks: Vec<F::Extension> = (0..self.num_chunks)
|
||||
.map(|i| vars.local_wires[self.wire_first_chunk_val(i)])
|
||||
.collect();
|
||||
let second_chunks: Vec<F::Extension> = (0..self.num_chunks)
|
||||
.map(|i| vars.local_wires[self.wire_second_chunk_val(i)])
|
||||
.collect();
|
||||
|
||||
let first_chunks_combined = reduce_with_powers(
|
||||
&first_chunks,
|
||||
F::Extension::from_canonical_usize(1 << self.chunk_bits()),
|
||||
);
|
||||
let second_chunks_combined = reduce_with_powers(
|
||||
&second_chunks,
|
||||
F::Extension::from_canonical_usize(1 << self.chunk_bits()),
|
||||
);
|
||||
|
||||
constraints.push(first_chunks_combined - first_input);
|
||||
constraints.push(second_chunks_combined - second_input);
|
||||
|
||||
let chunk_size = 1 << self.chunk_bits();
|
||||
|
||||
let mut most_significant_diff_so_far = F::Extension::ZERO;
|
||||
|
||||
for i in 0..self.num_chunks {
|
||||
// Range-check the chunks to be less than `chunk_size`.
|
||||
let first_product = (0..chunk_size)
|
||||
.map(|x| first_chunks[i] - F::Extension::from_canonical_usize(x))
|
||||
.product();
|
||||
let second_product = (0..chunk_size)
|
||||
.map(|x| second_chunks[i] - F::Extension::from_canonical_usize(x))
|
||||
.product();
|
||||
constraints.push(first_product);
|
||||
constraints.push(second_product);
|
||||
|
||||
let difference = second_chunks[i] - first_chunks[i];
|
||||
let equality_dummy = vars.local_wires[self.wire_equality_dummy(i)];
|
||||
let chunks_equal = vars.local_wires[self.wire_chunks_equal(i)];
|
||||
|
||||
// Two constraints to assert that `chunks_equal` is valid.
|
||||
constraints.push(difference * equality_dummy - (F::Extension::ONE - chunks_equal));
|
||||
constraints.push(chunks_equal * difference);
|
||||
|
||||
// Update `most_significant_diff_so_far`.
|
||||
let intermediate_value = vars.local_wires[self.wire_intermediate_value(i)];
|
||||
constraints.push(intermediate_value - chunks_equal * most_significant_diff_so_far);
|
||||
most_significant_diff_so_far =
|
||||
intermediate_value + (F::Extension::ONE - chunks_equal) * difference;
|
||||
}
|
||||
|
||||
let most_significant_diff = vars.local_wires[self.wire_most_significant_diff()];
|
||||
constraints.push(most_significant_diff - most_significant_diff_so_far);
|
||||
|
||||
// Range check `most_significant_diff` to be less than `chunk_size`.
|
||||
let product = (0..chunk_size)
|
||||
.map(|x| most_significant_diff - F::Extension::from_canonical_usize(x))
|
||||
.product();
|
||||
constraints.push(product);
|
||||
|
||||
constraints
|
||||
}
|
||||
|
||||
fn eval_unfiltered_base(&self, vars: EvaluationVarsBase<F>) -> Vec<F> {
|
||||
let mut constraints = Vec::with_capacity(self.num_constraints());
|
||||
|
||||
let first_input = vars.local_wires[self.wire_first_input()];
|
||||
let second_input = vars.local_wires[self.wire_second_input()];
|
||||
|
||||
// Get chunks and assert that they match
|
||||
let first_chunks: Vec<F> = (0..self.num_chunks)
|
||||
.map(|i| vars.local_wires[self.wire_first_chunk_val(i)])
|
||||
.collect();
|
||||
let second_chunks: Vec<F> = (0..self.num_chunks)
|
||||
.map(|i| vars.local_wires[self.wire_second_chunk_val(i)])
|
||||
.collect();
|
||||
|
||||
let first_chunks_combined = reduce_with_powers(
|
||||
&first_chunks,
|
||||
F::from_canonical_usize(1 << self.chunk_bits()),
|
||||
);
|
||||
let second_chunks_combined = reduce_with_powers(
|
||||
&second_chunks,
|
||||
F::from_canonical_usize(1 << self.chunk_bits()),
|
||||
);
|
||||
|
||||
constraints.push(first_chunks_combined - first_input);
|
||||
constraints.push(second_chunks_combined - second_input);
|
||||
|
||||
let chunk_size = 1 << self.chunk_bits();
|
||||
|
||||
let mut most_significant_diff_so_far = F::ZERO;
|
||||
|
||||
for i in 0..self.num_chunks {
|
||||
// Range-check the chunks to be less than `chunk_size`.
|
||||
let first_product = (0..chunk_size)
|
||||
.map(|x| first_chunks[i] - F::from_canonical_usize(x))
|
||||
.product();
|
||||
let second_product = (0..chunk_size)
|
||||
.map(|x| second_chunks[i] - F::from_canonical_usize(x))
|
||||
.product();
|
||||
constraints.push(first_product);
|
||||
constraints.push(second_product);
|
||||
|
||||
let difference = second_chunks[i] - first_chunks[i];
|
||||
let equality_dummy = vars.local_wires[self.wire_equality_dummy(i)];
|
||||
let chunks_equal = vars.local_wires[self.wire_chunks_equal(i)];
|
||||
|
||||
// Two constraints to assert that `chunks_equal` is valid.
|
||||
constraints.push(difference * equality_dummy - (F::ONE - chunks_equal));
|
||||
constraints.push(chunks_equal * difference);
|
||||
|
||||
// Update `most_significant_diff_so_far`.
|
||||
let intermediate_value = vars.local_wires[self.wire_intermediate_value(i)];
|
||||
constraints.push(intermediate_value - chunks_equal * most_significant_diff_so_far);
|
||||
most_significant_diff_so_far =
|
||||
intermediate_value + (F::ONE - chunks_equal) * difference;
|
||||
}
|
||||
|
||||
let most_significant_diff = vars.local_wires[self.wire_most_significant_diff()];
|
||||
constraints.push(most_significant_diff - most_significant_diff_so_far);
|
||||
|
||||
// Range check `most_significant_diff` to be less than `chunk_size`.
|
||||
let product = (0..chunk_size)
|
||||
.map(|x| most_significant_diff - F::from_canonical_usize(x))
|
||||
.product();
|
||||
constraints.push(product);
|
||||
|
||||
constraints
|
||||
}
|
||||
|
||||
fn eval_unfiltered_recursively(
|
||||
&self,
|
||||
builder: &mut CircuitBuilder<F, D>,
|
||||
vars: EvaluationTargets<D>,
|
||||
) -> Vec<ExtensionTarget<D>> {
|
||||
let mut constraints = Vec::with_capacity(self.num_constraints());
|
||||
|
||||
let first_input = vars.local_wires[self.wire_first_input()];
|
||||
let second_input = vars.local_wires[self.wire_second_input()];
|
||||
|
||||
// Get chunks and assert that they match
|
||||
let first_chunks: Vec<ExtensionTarget<D>> = (0..self.num_chunks)
|
||||
.map(|i| vars.local_wires[self.wire_first_chunk_val(i)])
|
||||
.collect();
|
||||
let second_chunks: Vec<ExtensionTarget<D>> = (0..self.num_chunks)
|
||||
.map(|i| vars.local_wires[self.wire_second_chunk_val(i)])
|
||||
.collect();
|
||||
|
||||
let chunk_base = builder.constant(F::from_canonical_usize(1 << self.chunk_bits()));
|
||||
let first_chunks_combined =
|
||||
reduce_with_powers_ext_recursive(builder, &first_chunks, chunk_base);
|
||||
let second_chunks_combined =
|
||||
reduce_with_powers_ext_recursive(builder, &second_chunks, chunk_base);
|
||||
|
||||
constraints.push(builder.sub_extension(first_chunks_combined, first_input));
|
||||
constraints.push(builder.sub_extension(second_chunks_combined, second_input));
|
||||
|
||||
let chunk_size = 1 << self.chunk_bits();
|
||||
|
||||
let mut most_significant_diff_so_far = builder.zero_extension();
|
||||
|
||||
let one = builder.one_extension();
|
||||
// Find the chosen chunk.
|
||||
for i in 0..self.num_chunks {
|
||||
// Range-check the chunks to be less than `chunk_size`.
|
||||
let mut first_product = one;
|
||||
let mut second_product = one;
|
||||
for x in 0..chunk_size {
|
||||
let x_f = builder.constant_extension(F::Extension::from_canonical_usize(x));
|
||||
let first_diff = builder.sub_extension(first_chunks[i], x_f);
|
||||
let second_diff = builder.sub_extension(second_chunks[i], x_f);
|
||||
first_product = builder.mul_extension(first_product, first_diff);
|
||||
second_product = builder.mul_extension(second_product, second_diff);
|
||||
}
|
||||
constraints.push(first_product);
|
||||
constraints.push(second_product);
|
||||
|
||||
let difference = builder.sub_extension(second_chunks[i], first_chunks[i]);
|
||||
let equality_dummy = vars.local_wires[self.wire_equality_dummy(i)];
|
||||
let chunks_equal = vars.local_wires[self.wire_chunks_equal(i)];
|
||||
|
||||
// Two constraints to assert that `chunks_equal` is valid.
|
||||
let diff_times_equal = builder.mul_extension(difference, equality_dummy);
|
||||
let not_equal = builder.sub_extension(one, chunks_equal);
|
||||
constraints.push(builder.sub_extension(diff_times_equal, not_equal));
|
||||
constraints.push(builder.mul_extension(chunks_equal, difference));
|
||||
|
||||
// Update `most_significant_diff_so_far`.
|
||||
let intermediate_value = vars.local_wires[self.wire_intermediate_value(i)];
|
||||
let old_diff = builder.mul_extension(chunks_equal, most_significant_diff_so_far);
|
||||
constraints.push(builder.sub_extension(intermediate_value, old_diff));
|
||||
|
||||
let not_equal = builder.sub_extension(one, chunks_equal);
|
||||
let new_diff = builder.mul_extension(not_equal, difference);
|
||||
most_significant_diff_so_far = builder.add_extension(intermediate_value, new_diff);
|
||||
}
|
||||
|
||||
let most_significant_diff = vars.local_wires[self.wire_most_significant_diff()];
|
||||
constraints
|
||||
.push(builder.sub_extension(most_significant_diff, most_significant_diff_so_far));
|
||||
|
||||
// Range check `most_significant_diff` to be less than `chunk_size`.
|
||||
let mut product = builder.one_extension();
|
||||
for x in 0..chunk_size {
|
||||
let x_f = builder.constant_extension(F::Extension::from_canonical_usize(x));
|
||||
let diff = builder.sub_extension(most_significant_diff, x_f);
|
||||
product = builder.mul_extension(product, diff);
|
||||
}
|
||||
constraints.push(product);
|
||||
|
||||
constraints
|
||||
}
|
||||
|
||||
fn generators(
|
||||
&self,
|
||||
gate_index: usize,
|
||||
_local_constants: &[F],
|
||||
) -> Vec<Box<dyn WitnessGenerator<F>>> {
|
||||
let gen = AssertLessThanGenerator::<F, D> {
|
||||
gate_index,
|
||||
gate: self.clone(),
|
||||
};
|
||||
vec![Box::new(gen.adapter())]
|
||||
}
|
||||
|
||||
fn num_wires(&self) -> usize {
|
||||
self.wire_intermediate_value(self.num_chunks - 1) + 1
|
||||
}
|
||||
|
||||
fn num_constants(&self) -> usize {
|
||||
0
|
||||
}
|
||||
|
||||
fn degree(&self) -> usize {
|
||||
1 << self.chunk_bits()
|
||||
}
|
||||
|
||||
fn num_constraints(&self) -> usize {
|
||||
4 + 5 * self.num_chunks
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct AssertLessThanGenerator<F: RichField + Extendable<D>, const D: usize> {
|
||||
gate_index: usize,
|
||||
gate: AssertLessThanGate<F, D>,
|
||||
}
|
||||
|
||||
impl<F: RichField + Extendable<D>, const D: usize> SimpleGenerator<F>
|
||||
for AssertLessThanGenerator<F, D>
|
||||
{
|
||||
fn dependencies(&self) -> Vec<Target> {
|
||||
let local_target = |input| Target::wire(self.gate_index, input);
|
||||
|
||||
vec![
|
||||
local_target(self.gate.wire_first_input()),
|
||||
local_target(self.gate.wire_second_input()),
|
||||
]
|
||||
}
|
||||
|
||||
fn run_once(&self, witness: &PartitionWitness<F>, out_buffer: &mut GeneratedValues<F>) {
|
||||
let local_wire = |input| Wire {
|
||||
gate: self.gate_index,
|
||||
input,
|
||||
};
|
||||
|
||||
let get_local_wire = |input| witness.get_wire(local_wire(input));
|
||||
|
||||
let first_input = get_local_wire(self.gate.wire_first_input());
|
||||
let second_input = get_local_wire(self.gate.wire_second_input());
|
||||
|
||||
let first_input_u64 = first_input.to_canonical_u64();
|
||||
let second_input_u64 = second_input.to_canonical_u64();
|
||||
|
||||
debug_assert!(first_input_u64 < second_input_u64);
|
||||
|
||||
let chunk_size = 1 << self.gate.chunk_bits();
|
||||
let first_input_chunks: Vec<F> = (0..self.gate.num_chunks)
|
||||
.scan(first_input_u64, |acc, _| {
|
||||
let tmp = *acc % chunk_size;
|
||||
*acc /= chunk_size;
|
||||
Some(F::from_canonical_u64(tmp))
|
||||
})
|
||||
.collect();
|
||||
let second_input_chunks: Vec<F> = (0..self.gate.num_chunks)
|
||||
.scan(second_input_u64, |acc, _| {
|
||||
let tmp = *acc % chunk_size;
|
||||
*acc /= chunk_size;
|
||||
Some(F::from_canonical_u64(tmp))
|
||||
})
|
||||
.collect();
|
||||
|
||||
let chunks_equal: Vec<F> = (0..self.gate.num_chunks)
|
||||
.map(|i| F::from_bool(first_input_chunks[i] == second_input_chunks[i]))
|
||||
.collect();
|
||||
let equality_dummies: Vec<F> = first_input_chunks
|
||||
.iter()
|
||||
.zip(second_input_chunks.iter())
|
||||
.map(|(&f, &s)| if f == s { F::ONE } else { F::ONE / (s - f) })
|
||||
.collect();
|
||||
|
||||
let mut most_significant_diff_so_far = F::ZERO;
|
||||
let mut intermediate_values = Vec::new();
|
||||
for i in 0..self.gate.num_chunks {
|
||||
if first_input_chunks[i] != second_input_chunks[i] {
|
||||
most_significant_diff_so_far = second_input_chunks[i] - first_input_chunks[i];
|
||||
intermediate_values.push(F::ZERO);
|
||||
} else {
|
||||
intermediate_values.push(most_significant_diff_so_far);
|
||||
}
|
||||
}
|
||||
let most_significant_diff = most_significant_diff_so_far;
|
||||
|
||||
out_buffer.set_wire(
|
||||
local_wire(self.gate.wire_most_significant_diff()),
|
||||
most_significant_diff,
|
||||
);
|
||||
for i in 0..self.gate.num_chunks {
|
||||
out_buffer.set_wire(
|
||||
local_wire(self.gate.wire_first_chunk_val(i)),
|
||||
first_input_chunks[i],
|
||||
);
|
||||
out_buffer.set_wire(
|
||||
local_wire(self.gate.wire_second_chunk_val(i)),
|
||||
second_input_chunks[i],
|
||||
);
|
||||
out_buffer.set_wire(
|
||||
local_wire(self.gate.wire_equality_dummy(i)),
|
||||
equality_dummies[i],
|
||||
);
|
||||
out_buffer.set_wire(local_wire(self.gate.wire_chunks_equal(i)), chunks_equal[i]);
|
||||
out_buffer.set_wire(
|
||||
local_wire(self.gate.wire_intermediate_value(i)),
|
||||
intermediate_values[i],
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::marker::PhantomData;
|
||||
|
||||
use anyhow::Result;
|
||||
use rand::Rng;
|
||||
|
||||
use crate::field::extension_field::quartic::QuarticExtension;
|
||||
use crate::field::field_types::{Field, PrimeField};
|
||||
use crate::field::goldilocks_field::GoldilocksField;
|
||||
use crate::gates::assert_le::AssertLessThanGate;
|
||||
use crate::gates::gate::Gate;
|
||||
use crate::gates::gate_testing::{test_eval_fns, test_low_degree};
|
||||
use crate::hash::hash_types::HashOut;
|
||||
use crate::plonk::vars::EvaluationVars;
|
||||
|
||||
#[test]
|
||||
fn wire_indices() {
|
||||
type AG = AssertLessThanGate<GoldilocksField, 4>;
|
||||
let num_bits = 40;
|
||||
let num_chunks = 5;
|
||||
|
||||
let gate = AG {
|
||||
num_bits,
|
||||
num_chunks,
|
||||
_phantom: PhantomData,
|
||||
};
|
||||
|
||||
assert_eq!(gate.wire_first_input(), 0);
|
||||
assert_eq!(gate.wire_second_input(), 1);
|
||||
assert_eq!(gate.wire_most_significant_diff(), 2);
|
||||
assert_eq!(gate.wire_first_chunk_val(0), 3);
|
||||
assert_eq!(gate.wire_first_chunk_val(4), 7);
|
||||
assert_eq!(gate.wire_second_chunk_val(0), 8);
|
||||
assert_eq!(gate.wire_second_chunk_val(4), 12);
|
||||
assert_eq!(gate.wire_equality_dummy(0), 13);
|
||||
assert_eq!(gate.wire_equality_dummy(4), 17);
|
||||
assert_eq!(gate.wire_chunks_equal(0), 18);
|
||||
assert_eq!(gate.wire_chunks_equal(4), 22);
|
||||
assert_eq!(gate.wire_intermediate_value(0), 23);
|
||||
assert_eq!(gate.wire_intermediate_value(4), 27);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn low_degree() {
|
||||
let num_bits = 20;
|
||||
let num_chunks = 4;
|
||||
|
||||
test_low_degree::<GoldilocksField, _, 4>(AssertLessThanGate::<_, 4>::new(
|
||||
num_bits, num_chunks,
|
||||
))
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn eval_fns() -> Result<()> {
|
||||
let num_bits = 20;
|
||||
let num_chunks = 4;
|
||||
|
||||
test_eval_fns::<GoldilocksField, _, 4>(AssertLessThanGate::<_, 4>::new(
|
||||
num_bits, num_chunks,
|
||||
))
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_gate_constraint() {
|
||||
type F = GoldilocksField;
|
||||
type FF = QuarticExtension<GoldilocksField>;
|
||||
const D: usize = 4;
|
||||
|
||||
let num_bits = 40;
|
||||
let num_chunks = 5;
|
||||
let chunk_bits = num_bits / num_chunks;
|
||||
|
||||
// Returns the local wires for an AssertLessThanGate given the two inputs.
|
||||
let get_wires = |first_input: F, second_input: F| -> Vec<FF> {
|
||||
let mut v = Vec::new();
|
||||
|
||||
let first_input_u64 = first_input.to_canonical_u64();
|
||||
let second_input_u64 = second_input.to_canonical_u64();
|
||||
|
||||
let chunk_size = 1 << chunk_bits;
|
||||
let mut first_input_chunks: Vec<F> = (0..num_chunks)
|
||||
.scan(first_input_u64, |acc, _| {
|
||||
let tmp = *acc % chunk_size;
|
||||
*acc /= chunk_size;
|
||||
Some(F::from_canonical_u64(tmp))
|
||||
})
|
||||
.collect();
|
||||
let mut second_input_chunks: Vec<F> = (0..num_chunks)
|
||||
.scan(second_input_u64, |acc, _| {
|
||||
let tmp = *acc % chunk_size;
|
||||
*acc /= chunk_size;
|
||||
Some(F::from_canonical_u64(tmp))
|
||||
})
|
||||
.collect();
|
||||
|
||||
let mut chunks_equal: Vec<F> = (0..num_chunks)
|
||||
.map(|i| F::from_bool(first_input_chunks[i] == second_input_chunks[i]))
|
||||
.collect();
|
||||
let mut equality_dummies: Vec<F> = first_input_chunks
|
||||
.iter()
|
||||
.zip(second_input_chunks.iter())
|
||||
.map(|(&f, &s)| if f == s { F::ONE } else { F::ONE / (s - f) })
|
||||
.collect();
|
||||
|
||||
let mut most_significant_diff_so_far = F::ZERO;
|
||||
let mut intermediate_values = Vec::new();
|
||||
for i in 0..num_chunks {
|
||||
if first_input_chunks[i] != second_input_chunks[i] {
|
||||
most_significant_diff_so_far = second_input_chunks[i] - first_input_chunks[i];
|
||||
intermediate_values.push(F::ZERO);
|
||||
} else {
|
||||
intermediate_values.push(most_significant_diff_so_far);
|
||||
}
|
||||
}
|
||||
let most_significant_diff = most_significant_diff_so_far;
|
||||
|
||||
v.push(first_input);
|
||||
v.push(second_input);
|
||||
v.push(most_significant_diff);
|
||||
v.append(&mut first_input_chunks);
|
||||
v.append(&mut second_input_chunks);
|
||||
v.append(&mut equality_dummies);
|
||||
v.append(&mut chunks_equal);
|
||||
v.append(&mut intermediate_values);
|
||||
|
||||
v.iter().map(|&x| x.into()).collect::<Vec<_>>()
|
||||
};
|
||||
|
||||
let mut rng = rand::thread_rng();
|
||||
let max: u64 = 1 << (num_bits - 1);
|
||||
let first_input_u64 = rng.gen_range(0..max);
|
||||
let second_input_u64 = {
|
||||
let mut val = rng.gen_range(0..max);
|
||||
while val < first_input_u64 {
|
||||
val = rng.gen_range(0..max);
|
||||
}
|
||||
val
|
||||
};
|
||||
|
||||
let first_input = F::from_canonical_u64(first_input_u64);
|
||||
let second_input = F::from_canonical_u64(second_input_u64);
|
||||
|
||||
let less_than_gate = AssertLessThanGate::<F, D> {
|
||||
num_bits,
|
||||
num_chunks,
|
||||
_phantom: PhantomData,
|
||||
};
|
||||
let less_than_vars = EvaluationVars {
|
||||
local_constants: &[],
|
||||
local_wires: &get_wires(first_input, second_input),
|
||||
public_inputs_hash: &HashOut::rand(),
|
||||
};
|
||||
assert!(
|
||||
less_than_gate
|
||||
.eval_unfiltered(less_than_vars)
|
||||
.iter()
|
||||
.all(|x| x.is_zero()),
|
||||
"Gate constraints are not satisfied."
|
||||
);
|
||||
|
||||
let equal_gate = AssertLessThanGate::<F, D> {
|
||||
num_bits,
|
||||
num_chunks,
|
||||
_phantom: PhantomData,
|
||||
};
|
||||
let equal_vars = EvaluationVars {
|
||||
local_constants: &[],
|
||||
local_wires: &get_wires(first_input, first_input),
|
||||
public_inputs_hash: &HashOut::rand(),
|
||||
};
|
||||
assert!(
|
||||
equal_gate
|
||||
.eval_unfiltered(equal_vars)
|
||||
.iter()
|
||||
.all(|x| x.is_zero()),
|
||||
"Gate constraints are not satisfied."
|
||||
);
|
||||
}
|
||||
}
|
||||
@ -24,8 +24,7 @@ impl<const B: usize> BaseSumGate<B> {
|
||||
}
|
||||
|
||||
pub fn new_from_config<F: PrimeField>(config: &CircuitConfig) -> Self {
|
||||
let num_limbs = ((F::ORDER as f64).log(B as f64).floor() as usize)
|
||||
.min(config.num_routed_wires - Self::START_LIMBS);
|
||||
let num_limbs = F::BITS.min(config.num_routed_wires - Self::START_LIMBS);
|
||||
Self::new(num_limbs)
|
||||
}
|
||||
|
||||
|
||||
@ -43,33 +43,42 @@ impl<F: Extendable<D>, const D: usize> ComparisonGate<F, D> {
|
||||
1
|
||||
}
|
||||
|
||||
pub fn wire_most_significant_diff(&self) -> usize {
|
||||
pub fn wire_result_bool(&self) -> usize {
|
||||
2
|
||||
}
|
||||
|
||||
pub fn wire_most_significant_diff(&self) -> usize {
|
||||
3
|
||||
}
|
||||
|
||||
pub fn wire_first_chunk_val(&self, chunk: usize) -> usize {
|
||||
debug_assert!(chunk < self.num_chunks);
|
||||
3 + chunk
|
||||
4 + chunk
|
||||
}
|
||||
|
||||
pub fn wire_second_chunk_val(&self, chunk: usize) -> usize {
|
||||
debug_assert!(chunk < self.num_chunks);
|
||||
3 + self.num_chunks + chunk
|
||||
4 + self.num_chunks + chunk
|
||||
}
|
||||
|
||||
pub fn wire_equality_dummy(&self, chunk: usize) -> usize {
|
||||
debug_assert!(chunk < self.num_chunks);
|
||||
3 + 2 * self.num_chunks + chunk
|
||||
4 + 2 * self.num_chunks + chunk
|
||||
}
|
||||
|
||||
pub fn wire_chunks_equal(&self, chunk: usize) -> usize {
|
||||
debug_assert!(chunk < self.num_chunks);
|
||||
3 + 3 * self.num_chunks + chunk
|
||||
4 + 3 * self.num_chunks + chunk
|
||||
}
|
||||
|
||||
pub fn wire_intermediate_value(&self, chunk: usize) -> usize {
|
||||
debug_assert!(chunk < self.num_chunks);
|
||||
3 + 4 * self.num_chunks + chunk
|
||||
4 + 4 * self.num_chunks + chunk
|
||||
}
|
||||
|
||||
/// The `bit_index`th bit of 2^n - 1 + most_significant_diff.
|
||||
pub fn wire_most_significant_diff_bit(&self, bit_index: usize) -> usize {
|
||||
4 + 5 * self.num_chunks + bit_index
|
||||
}
|
||||
}
|
||||
|
||||
@ -110,10 +119,10 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for ComparisonGate<F, D> {
|
||||
|
||||
for i in 0..self.num_chunks {
|
||||
// Range-check the chunks to be less than `chunk_size`.
|
||||
let first_product = (0..chunk_size)
|
||||
let first_product: F::Extension = (0..chunk_size)
|
||||
.map(|x| first_chunks[i] - F::Extension::from_canonical_usize(x))
|
||||
.product();
|
||||
let second_product = (0..chunk_size)
|
||||
let second_product: F::Extension = (0..chunk_size)
|
||||
.map(|x| second_chunks[i] - F::Extension::from_canonical_usize(x))
|
||||
.product();
|
||||
constraints.push(first_product);
|
||||
@ -137,11 +146,22 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for ComparisonGate<F, D> {
|
||||
let most_significant_diff = vars.local_wires[self.wire_most_significant_diff()];
|
||||
constraints.push(most_significant_diff - most_significant_diff_so_far);
|
||||
|
||||
// Range check `most_significant_diff` to be less than `chunk_size`.
|
||||
let product = (0..chunk_size)
|
||||
.map(|x| most_significant_diff - F::Extension::from_canonical_usize(x))
|
||||
.product();
|
||||
constraints.push(product);
|
||||
let most_significant_diff_bits: Vec<F::Extension> = (0..self.chunk_bits() + 1)
|
||||
.map(|i| vars.local_wires[self.wire_most_significant_diff_bit(i)])
|
||||
.collect();
|
||||
|
||||
// Range-check the bits.
|
||||
for &bit in &most_significant_diff_bits {
|
||||
constraints.push(bit * (F::Extension::ONE - bit));
|
||||
}
|
||||
|
||||
let bits_combined = reduce_with_powers(&most_significant_diff_bits, F::Extension::TWO);
|
||||
let two_n = F::Extension::from_canonical_u64(1 << self.chunk_bits());
|
||||
constraints.push((two_n + most_significant_diff) - bits_combined);
|
||||
|
||||
// Iff first <= second, the top (n + 1st) bit of (2^n + most_significant_diff) will be 1.
|
||||
let result_bool = vars.local_wires[self.wire_result_bool()];
|
||||
constraints.push(result_bool - most_significant_diff_bits[self.chunk_bits()]);
|
||||
|
||||
constraints
|
||||
}
|
||||
@ -178,10 +198,10 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for ComparisonGate<F, D> {
|
||||
|
||||
for i in 0..self.num_chunks {
|
||||
// Range-check the chunks to be less than `chunk_size`.
|
||||
let first_product = (0..chunk_size)
|
||||
let first_product: F = (0..chunk_size)
|
||||
.map(|x| first_chunks[i] - F::from_canonical_usize(x))
|
||||
.product();
|
||||
let second_product = (0..chunk_size)
|
||||
let second_product: F = (0..chunk_size)
|
||||
.map(|x| second_chunks[i] - F::from_canonical_usize(x))
|
||||
.product();
|
||||
constraints.push(first_product);
|
||||
@ -205,11 +225,22 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for ComparisonGate<F, D> {
|
||||
let most_significant_diff = vars.local_wires[self.wire_most_significant_diff()];
|
||||
constraints.push(most_significant_diff - most_significant_diff_so_far);
|
||||
|
||||
// Range check `most_significant_diff` to be less than `chunk_size`.
|
||||
let product = (0..chunk_size)
|
||||
.map(|x| most_significant_diff - F::from_canonical_usize(x))
|
||||
.product();
|
||||
constraints.push(product);
|
||||
let most_significant_diff_bits: Vec<F> = (0..self.chunk_bits() + 1)
|
||||
.map(|i| vars.local_wires[self.wire_most_significant_diff_bit(i)])
|
||||
.collect();
|
||||
|
||||
// Range-check the bits.
|
||||
for &bit in &most_significant_diff_bits {
|
||||
constraints.push(bit * (F::ONE - bit));
|
||||
}
|
||||
|
||||
let bits_combined = reduce_with_powers(&most_significant_diff_bits, F::TWO);
|
||||
let two_n = F::from_canonical_u64(1 << self.chunk_bits());
|
||||
constraints.push((two_n + most_significant_diff) - bits_combined);
|
||||
|
||||
// Iff first <= second, the top (n + 1st) bit of (2^n - 1 + most_significant_diff) will be 1.
|
||||
let result_bool = vars.local_wires[self.wire_result_bool()];
|
||||
constraints.push(result_bool - most_significant_diff_bits[self.chunk_bits()]);
|
||||
|
||||
constraints
|
||||
}
|
||||
@ -285,14 +316,29 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for ComparisonGate<F, D> {
|
||||
constraints
|
||||
.push(builder.sub_extension(most_significant_diff, most_significant_diff_so_far));
|
||||
|
||||
// Range check `most_significant_diff` to be less than `chunk_size`.
|
||||
let mut product = builder.one_extension();
|
||||
for x in 0..chunk_size {
|
||||
let x_f = builder.constant_extension(F::Extension::from_canonical_usize(x));
|
||||
let diff = builder.sub_extension(most_significant_diff, x_f);
|
||||
product = builder.mul_extension(product, diff);
|
||||
let most_significant_diff_bits: Vec<ExtensionTarget<D>> = (0..self.chunk_bits() + 1)
|
||||
.map(|i| vars.local_wires[self.wire_most_significant_diff_bit(i)])
|
||||
.collect();
|
||||
|
||||
// Range-check the bits.
|
||||
for &this_bit in &most_significant_diff_bits {
|
||||
let inverse = builder.sub_extension(one, this_bit);
|
||||
constraints.push(builder.mul_extension(this_bit, inverse));
|
||||
}
|
||||
constraints.push(product);
|
||||
|
||||
let two = builder.two();
|
||||
let bits_combined =
|
||||
reduce_with_powers_ext_recursive(builder, &most_significant_diff_bits, two);
|
||||
let two_n =
|
||||
builder.constant_extension(F::Extension::from_canonical_u64(1 << self.chunk_bits()));
|
||||
let sum = builder.add_extension(two_n, most_significant_diff);
|
||||
constraints.push(builder.sub_extension(sum, bits_combined));
|
||||
|
||||
// Iff first <= second, the top (n + 1st) bit of (2^n + most_significant_diff) will be 1.
|
||||
let result_bool = vars.local_wires[self.wire_result_bool()];
|
||||
constraints.push(
|
||||
builder.sub_extension(result_bool, most_significant_diff_bits[self.chunk_bits()]),
|
||||
);
|
||||
|
||||
constraints
|
||||
}
|
||||
@ -310,7 +356,7 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for ComparisonGate<F, D> {
|
||||
}
|
||||
|
||||
fn num_wires(&self) -> usize {
|
||||
self.wire_intermediate_value(self.num_chunks - 1) + 1
|
||||
4 + 5 * self.num_chunks + (self.chunk_bits() + 1)
|
||||
}
|
||||
|
||||
fn num_constants(&self) -> usize {
|
||||
@ -322,7 +368,7 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for ComparisonGate<F, D> {
|
||||
}
|
||||
|
||||
fn num_constraints(&self) -> usize {
|
||||
4 + 5 * self.num_chunks
|
||||
6 + 5 * self.num_chunks + self.chunk_bits()
|
||||
}
|
||||
}
|
||||
|
||||
@ -336,10 +382,10 @@ impl<F: Extendable<D>, const D: usize> SimpleGenerator<F> for ComparisonGenerato
|
||||
fn dependencies(&self) -> Vec<Target> {
|
||||
let local_target = |input| Target::wire(self.gate_index, input);
|
||||
|
||||
let mut deps = Vec::new();
|
||||
deps.push(local_target(self.gate.wire_first_input()));
|
||||
deps.push(local_target(self.gate.wire_second_input()));
|
||||
deps
|
||||
vec![
|
||||
local_target(self.gate.wire_first_input()),
|
||||
local_target(self.gate.wire_second_input()),
|
||||
]
|
||||
}
|
||||
|
||||
fn run_once(&self, witness: &PartitionWitness<F>, out_buffer: &mut GeneratedValues<F>) {
|
||||
@ -356,7 +402,7 @@ impl<F: Extendable<D>, const D: usize> SimpleGenerator<F> for ComparisonGenerato
|
||||
let first_input_u64 = first_input.to_canonical_u64();
|
||||
let second_input_u64 = second_input.to_canonical_u64();
|
||||
|
||||
debug_assert!(first_input_u64 < second_input_u64);
|
||||
let result = F::from_canonical_usize((first_input_u64 <= second_input_u64) as usize);
|
||||
|
||||
let chunk_size = 1 << self.gate.chunk_bits();
|
||||
let first_input_chunks: Vec<F> = (0..self.gate.num_chunks)
|
||||
@ -395,6 +441,22 @@ impl<F: Extendable<D>, const D: usize> SimpleGenerator<F> for ComparisonGenerato
|
||||
}
|
||||
let most_significant_diff = most_significant_diff_so_far;
|
||||
|
||||
let two_n = F::from_canonical_usize(1 << self.gate.chunk_bits());
|
||||
let two_n_plus_msd = (two_n + most_significant_diff).to_canonical_u64();
|
||||
|
||||
let msd_bits_u64: Vec<u64> = (0..self.gate.chunk_bits() + 1)
|
||||
.scan(two_n_plus_msd, |acc, _| {
|
||||
let tmp = *acc % 2;
|
||||
*acc /= 2;
|
||||
Some(tmp)
|
||||
})
|
||||
.collect();
|
||||
let msd_bits: Vec<F> = msd_bits_u64
|
||||
.iter()
|
||||
.map(|x| F::from_canonical_u64(*x))
|
||||
.collect();
|
||||
|
||||
out_buffer.set_wire(local_wire(self.gate.wire_result_bool()), result);
|
||||
out_buffer.set_wire(
|
||||
local_wire(self.gate.wire_most_significant_diff()),
|
||||
most_significant_diff,
|
||||
@ -418,6 +480,12 @@ impl<F: Extendable<D>, const D: usize> SimpleGenerator<F> for ComparisonGenerato
|
||||
intermediate_values[i],
|
||||
);
|
||||
}
|
||||
for i in 0..self.gate.chunk_bits() + 1 {
|
||||
out_buffer.set_wire(
|
||||
local_wire(self.gate.wire_most_significant_diff_bit(i)),
|
||||
msd_bits[i],
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -451,17 +519,20 @@ mod tests {
|
||||
|
||||
assert_eq!(gate.wire_first_input(), 0);
|
||||
assert_eq!(gate.wire_second_input(), 1);
|
||||
assert_eq!(gate.wire_most_significant_diff(), 2);
|
||||
assert_eq!(gate.wire_first_chunk_val(0), 3);
|
||||
assert_eq!(gate.wire_first_chunk_val(4), 7);
|
||||
assert_eq!(gate.wire_second_chunk_val(0), 8);
|
||||
assert_eq!(gate.wire_second_chunk_val(4), 12);
|
||||
assert_eq!(gate.wire_equality_dummy(0), 13);
|
||||
assert_eq!(gate.wire_equality_dummy(4), 17);
|
||||
assert_eq!(gate.wire_chunks_equal(0), 18);
|
||||
assert_eq!(gate.wire_chunks_equal(4), 22);
|
||||
assert_eq!(gate.wire_intermediate_value(0), 23);
|
||||
assert_eq!(gate.wire_intermediate_value(4), 27);
|
||||
assert_eq!(gate.wire_result_bool(), 2);
|
||||
assert_eq!(gate.wire_most_significant_diff(), 3);
|
||||
assert_eq!(gate.wire_first_chunk_val(0), 4);
|
||||
assert_eq!(gate.wire_first_chunk_val(4), 8);
|
||||
assert_eq!(gate.wire_second_chunk_val(0), 9);
|
||||
assert_eq!(gate.wire_second_chunk_val(4), 13);
|
||||
assert_eq!(gate.wire_equality_dummy(0), 14);
|
||||
assert_eq!(gate.wire_equality_dummy(4), 18);
|
||||
assert_eq!(gate.wire_chunks_equal(0), 19);
|
||||
assert_eq!(gate.wire_chunks_equal(4), 23);
|
||||
assert_eq!(gate.wire_intermediate_value(0), 24);
|
||||
assert_eq!(gate.wire_intermediate_value(4), 28);
|
||||
assert_eq!(gate.wire_most_significant_diff_bit(0), 29);
|
||||
assert_eq!(gate.wire_most_significant_diff_bit(8), 37);
|
||||
}
|
||||
|
||||
#[test]
|
||||
@ -501,6 +572,8 @@ mod tests {
|
||||
let first_input_u64 = first_input.to_canonical_u64();
|
||||
let second_input_u64 = second_input.to_canonical_u64();
|
||||
|
||||
let result_bool = F::from_bool(first_input_u64 <= second_input_u64);
|
||||
|
||||
let chunk_size = 1 << chunk_bits;
|
||||
let mut first_input_chunks: Vec<F> = (0..num_chunks)
|
||||
.scan(first_input_u64, |acc, _| {
|
||||
@ -538,20 +611,32 @@ mod tests {
|
||||
}
|
||||
let most_significant_diff = most_significant_diff_so_far;
|
||||
|
||||
let two_n_plus_msd =
|
||||
(1 << chunk_bits) as u64 + most_significant_diff.to_canonical_u64();
|
||||
let mut msd_bits: Vec<F> = (0..chunk_bits + 1)
|
||||
.scan(two_n_plus_msd, |acc, _| {
|
||||
let tmp = *acc % 2;
|
||||
*acc /= 2;
|
||||
Some(F::from_canonical_u64(tmp))
|
||||
})
|
||||
.collect();
|
||||
|
||||
v.push(first_input);
|
||||
v.push(second_input);
|
||||
v.push(result_bool);
|
||||
v.push(most_significant_diff);
|
||||
v.append(&mut first_input_chunks);
|
||||
v.append(&mut second_input_chunks);
|
||||
v.append(&mut equality_dummies);
|
||||
v.append(&mut chunks_equal);
|
||||
v.append(&mut intermediate_values);
|
||||
v.append(&mut msd_bits);
|
||||
|
||||
v.iter().map(|&x| x.into()).collect::<Vec<_>>()
|
||||
};
|
||||
|
||||
let mut rng = rand::thread_rng();
|
||||
let max: u64 = 1 << num_bits - 1;
|
||||
let max: u64 = 1 << (num_bits - 1);
|
||||
let first_input_u64 = rng.gen_range(0..max);
|
||||
let second_input_u64 = {
|
||||
let mut val = rng.gen_range(0..max);
|
||||
|
||||
@ -337,9 +337,8 @@ mod tests {
|
||||
.map(|b| F::from_canonical_u64(*b))
|
||||
.collect();
|
||||
|
||||
let mut v = Vec::new();
|
||||
v.push(base);
|
||||
v.extend(power_bits_f.clone());
|
||||
let mut v = vec![base];
|
||||
v.extend(power_bits_f);
|
||||
|
||||
let mut intermediate_values = Vec::new();
|
||||
let mut current_intermediate_value = F::ONE;
|
||||
|
||||
@ -10,7 +10,7 @@ use crate::plonk::circuit_data::CircuitConfig;
|
||||
use crate::plonk::config::GenericConfig;
|
||||
use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase};
|
||||
use crate::plonk::verifier::verify;
|
||||
use crate::polynomial::polynomial::{PolynomialCoeffs, PolynomialValues};
|
||||
use crate::polynomial::{PolynomialCoeffs, PolynomialValues};
|
||||
use crate::util::{log2_ceil, transpose};
|
||||
|
||||
const WITNESS_SIZE: usize = 1 << 5;
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
use log::info;
|
||||
use log::debug;
|
||||
|
||||
use crate::field::extension_field::Extendable;
|
||||
use crate::field::field_types::RichField;
|
||||
@ -86,7 +86,7 @@ impl<F: RichField + Extendable<D>, const D: usize> Tree<GateRef<F, D>> {
|
||||
}
|
||||
}
|
||||
}
|
||||
info!(
|
||||
debug!(
|
||||
"Found tree with max degree {} and {} constants wires in {:.4}s.",
|
||||
best_degree,
|
||||
best_num_constants,
|
||||
@ -221,12 +221,17 @@ impl<F: RichField + Extendable<D>, const D: usize> Tree<GateRef<F, D>> {
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use log::info;
|
||||
|
||||
use super::*;
|
||||
use crate::field::goldilocks_field::GoldilocksField;
|
||||
use crate::gadgets::interpolation::InterpolationGate;
|
||||
use crate::gates::arithmetic_extension::ArithmeticExtensionGate;
|
||||
use crate::gates::arithmetic::ArithmeticExtensionGate;
|
||||
use crate::gates::base_sum::BaseSumGate;
|
||||
use crate::gates::constant::ConstantGate;
|
||||
use crate::gates::gmimc::GMiMCGate;
|
||||
use crate::gates::interpolation::InterpolationGate;
|
||||
use crate::gates::interpolation::HighDegreeInterpolationGate;
|
||||
use crate::gates::noop::NoopGate;
|
||||
use crate::plonk::config::{GenericConfig, PoseidonGoldilocksConfig};
|
||||
|
||||
@ -243,7 +248,7 @@ mod tests {
|
||||
GateRef::new(ArithmeticExtensionGate { num_ops: 4 }),
|
||||
GateRef::new(BaseSumGate::<4>::new(4)),
|
||||
GateRef::new(GMiMCGate::<F, D, 12>::new()),
|
||||
GateRef::new(InterpolationGate::new(2)),
|
||||
GateRef::new(HighDegreeInterpolationGate::new(2)),
|
||||
];
|
||||
|
||||
let (tree, _, _) = Tree::from_gates(gates.clone());
|
||||
|
||||
@ -318,8 +318,6 @@ impl<F: Extendable<D> + GMiMC<WIDTH>, const D: usize, const WIDTH: usize> Simple
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::convert::TryInto;
|
||||
|
||||
use anyhow::Result;
|
||||
|
||||
use crate::field::field_types::Field;
|
||||
|
||||
@ -1,4 +1,3 @@
|
||||
use std::convert::TryInto;
|
||||
use std::marker::PhantomData;
|
||||
use std::ops::Range;
|
||||
|
||||
@ -252,8 +251,7 @@ impl<F: RichField + Extendable<D>, const D: usize> SimpleGenerator<F> for Insert
|
||||
|
||||
let local_targets = |inputs: Range<usize>| inputs.map(local_target);
|
||||
|
||||
let mut deps = Vec::new();
|
||||
deps.push(local_target(self.gate.wires_insertion_index()));
|
||||
let mut deps = vec![local_target(self.gate.wires_insertion_index())];
|
||||
deps.extend(local_targets(self.gate.wires_element_to_insert()));
|
||||
for i in 0..self.gate.vec_size {
|
||||
deps.extend(local_targets(self.gate.wires_original_list_item(i)));
|
||||
@ -292,7 +290,7 @@ impl<F: RichField + Extendable<D>, const D: usize> SimpleGenerator<F> for Insert
|
||||
vec_size
|
||||
);
|
||||
|
||||
let mut new_vec = orig_vec.clone();
|
||||
let mut new_vec = orig_vec;
|
||||
new_vec.insert(insertion_index, to_insert);
|
||||
|
||||
let mut equality_dummy_vals = Vec::new();
|
||||
@ -377,14 +375,13 @@ mod tests {
|
||||
fn get_wires(orig_vec: Vec<FF>, insertion_index: usize, element_to_insert: FF) -> Vec<FF> {
|
||||
let vec_size = orig_vec.len();
|
||||
|
||||
let mut v = Vec::new();
|
||||
v.push(F::from_canonical_usize(insertion_index));
|
||||
let mut v = vec![F::from_canonical_usize(insertion_index)];
|
||||
v.extend(element_to_insert.0);
|
||||
for j in 0..vec_size {
|
||||
v.extend(orig_vec[j].0);
|
||||
}
|
||||
|
||||
let mut new_vec = orig_vec.clone();
|
||||
let mut new_vec = orig_vec;
|
||||
new_vec.insert(insertion_index, element_to_insert);
|
||||
let mut equality_dummy_vals = Vec::new();
|
||||
for i in 0..=vec_size {
|
||||
|
||||
@ -1,4 +1,3 @@
|
||||
use std::convert::TryInto;
|
||||
use std::marker::PhantomData;
|
||||
use std::ops::Range;
|
||||
|
||||
@ -6,6 +5,7 @@ use crate::field::extension_field::algebra::PolynomialCoeffsAlgebra;
|
||||
use crate::field::extension_field::target::ExtensionTarget;
|
||||
use crate::field::extension_field::{Extendable, FieldExtension};
|
||||
use crate::field::interpolation::interpolant;
|
||||
use crate::gadgets::interpolation::InterpolationGate;
|
||||
use crate::gadgets::polynomial::PolynomialCoeffsExtAlgebraTarget;
|
||||
use crate::gates::gate::Gate;
|
||||
use crate::iop::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator};
|
||||
@ -14,19 +14,20 @@ use crate::iop::wire::Wire;
|
||||
use crate::iop::witness::{PartitionWitness, Witness};
|
||||
use crate::plonk::circuit_builder::CircuitBuilder;
|
||||
use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase};
|
||||
use crate::polynomial::polynomial::PolynomialCoeffs;
|
||||
use crate::polynomial::PolynomialCoeffs;
|
||||
|
||||
/// Interpolates a polynomial, whose points are a (base field) coset of the multiplicative subgroup
|
||||
/// with the given size, and whose values are extension field elements, given by input wires.
|
||||
/// Outputs the evaluation of the interpolant at a given (extension field) evaluation point.
|
||||
#[derive(Clone, Debug)]
|
||||
pub(crate) struct InterpolationGate<F: Extendable<D>, const D: usize> {
|
||||
/// Interpolation gate with constraints of degree at most `1<<subgroup_bits`.
|
||||
/// `eval_unfiltered_recursively` uses less gates than `LowDegreeInterpolationGate`.
|
||||
#[derive(Copy, Clone, Debug)]
|
||||
pub(crate) struct HighDegreeInterpolationGate<F: RichField + Extendable<D>, const D: usize> {
|
||||
pub subgroup_bits: usize,
|
||||
_phantom: PhantomData<F>,
|
||||
}
|
||||
|
||||
impl<F: Extendable<D>, const D: usize> InterpolationGate<F, D> {
|
||||
pub fn new(subgroup_bits: usize) -> Self {
|
||||
impl<F: Extendable<D>, const D: usize> InterpolationGate<F, D>
|
||||
for HighDegreeInterpolationGate<F, D>
|
||||
{
|
||||
fn new(subgroup_bits: usize) -> Self {
|
||||
Self {
|
||||
subgroup_bits,
|
||||
_phantom: PhantomData,
|
||||
@ -36,60 +37,9 @@ impl<F: Extendable<D>, const D: usize> InterpolationGate<F, D> {
|
||||
fn num_points(&self) -> usize {
|
||||
1 << self.subgroup_bits
|
||||
}
|
||||
}
|
||||
|
||||
/// Wire index of the coset shift.
|
||||
pub fn wire_shift(&self) -> usize {
|
||||
0
|
||||
}
|
||||
|
||||
fn start_values(&self) -> usize {
|
||||
1
|
||||
}
|
||||
|
||||
/// Wire indices of the `i`th interpolant value.
|
||||
pub fn wires_value(&self, i: usize) -> Range<usize> {
|
||||
debug_assert!(i < self.num_points());
|
||||
let start = self.start_values() + i * D;
|
||||
start..start + D
|
||||
}
|
||||
|
||||
fn start_evaluation_point(&self) -> usize {
|
||||
self.start_values() + self.num_points() * D
|
||||
}
|
||||
|
||||
/// Wire indices of the point to evaluate the interpolant at.
|
||||
pub fn wires_evaluation_point(&self) -> Range<usize> {
|
||||
let start = self.start_evaluation_point();
|
||||
start..start + D
|
||||
}
|
||||
|
||||
fn start_evaluation_value(&self) -> usize {
|
||||
self.start_evaluation_point() + D
|
||||
}
|
||||
|
||||
/// Wire indices of the interpolated value.
|
||||
pub fn wires_evaluation_value(&self) -> Range<usize> {
|
||||
let start = self.start_evaluation_value();
|
||||
start..start + D
|
||||
}
|
||||
|
||||
fn start_coeffs(&self) -> usize {
|
||||
self.start_evaluation_value() + D
|
||||
}
|
||||
|
||||
/// The number of routed wires required in the typical usage of this gate, where the points to
|
||||
/// interpolate, the evaluation point, and the corresponding value are all routed.
|
||||
pub(crate) fn num_routed_wires(&self) -> usize {
|
||||
self.start_coeffs()
|
||||
}
|
||||
|
||||
/// Wire indices of the interpolant's `i`th coefficient.
|
||||
pub fn wires_coeff(&self, i: usize) -> Range<usize> {
|
||||
debug_assert!(i < self.num_points());
|
||||
let start = self.start_coeffs() + i * D;
|
||||
start..start + D
|
||||
}
|
||||
|
||||
impl<F: RichField + Extendable<D>, const D: usize> HighDegreeInterpolationGate<F, D> {
|
||||
/// End of wire indices, exclusive.
|
||||
fn end(&self) -> usize {
|
||||
self.start_coeffs() + self.num_points() * D
|
||||
@ -121,14 +71,16 @@ impl<F: Extendable<D>, const D: usize> InterpolationGate<F, D> {
|
||||
g.powers()
|
||||
.take(size)
|
||||
.map(move |x| {
|
||||
let subgroup_element = builder.constant(x.into());
|
||||
let subgroup_element = builder.constant(x);
|
||||
builder.scalar_mul_ext(subgroup_element, shift)
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
|
||||
impl<F: Extendable<D>, const D: usize> Gate<F, D> for InterpolationGate<F, D> {
|
||||
impl<F: Extendable<D>, const D: usize> Gate<F, D>
|
||||
for HighDegreeInterpolationGate<F, D>
|
||||
{
|
||||
fn id(&self) -> String {
|
||||
format!("{:?}<D={}>", self, D)
|
||||
}
|
||||
@ -221,7 +173,7 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for InterpolationGate<F, D> {
|
||||
) -> Vec<Box<dyn WitnessGenerator<F>>> {
|
||||
let gen = InterpolationGenerator::<F, D> {
|
||||
gate_index,
|
||||
gate: self.clone(),
|
||||
gate: *self,
|
||||
_phantom: PhantomData,
|
||||
};
|
||||
vec![Box::new(gen.adapter())]
|
||||
@ -251,7 +203,7 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for InterpolationGate<F, D> {
|
||||
#[derive(Debug)]
|
||||
struct InterpolationGenerator<F: Extendable<D>, const D: usize> {
|
||||
gate_index: usize,
|
||||
gate: InterpolationGate<F, D>,
|
||||
gate: HighDegreeInterpolationGate<F, D>,
|
||||
_phantom: PhantomData<F>,
|
||||
}
|
||||
|
||||
@ -321,17 +273,18 @@ mod tests {
|
||||
|
||||
use crate::field::field_types::Field;
|
||||
use crate::field::goldilocks_field::GoldilocksField;
|
||||
use crate::gadgets::interpolation::InterpolationGate;
|
||||
use crate::gates::gate::Gate;
|
||||
use crate::gates::gate_testing::{test_eval_fns, test_low_degree};
|
||||
use crate::gates::interpolation::InterpolationGate;
|
||||
use crate::gates::interpolation::HighDegreeInterpolationGate;
|
||||
use crate::hash::hash_types::HashOut;
|
||||
use crate::plonk::config::{GenericConfig, PoseidonGoldilocksConfig};
|
||||
use crate::plonk::vars::EvaluationVars;
|
||||
use crate::polynomial::polynomial::PolynomialCoeffs;
|
||||
use crate::polynomial::PolynomialCoeffs;
|
||||
|
||||
#[test]
|
||||
fn wire_indices() {
|
||||
let gate = InterpolationGate::<GoldilocksField, 4> {
|
||||
let gate = HighDegreeInterpolationGate::<GoldilocksField, 4> {
|
||||
subgroup_bits: 1,
|
||||
_phantom: PhantomData,
|
||||
};
|
||||
@ -350,7 +303,7 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn low_degree() {
|
||||
test_low_degree::<GoldilocksField, _, 4>(InterpolationGate::new(2));
|
||||
test_low_degree::<GoldilocksField, _, 4>(HighDegreeInterpolationGate::new(2));
|
||||
}
|
||||
|
||||
#[test]
|
||||
@ -358,7 +311,7 @@ mod tests {
|
||||
const D: usize = 2;
|
||||
type C = PoseidonGoldilocksConfig;
|
||||
type F = <C as GenericConfig<D>>::F;
|
||||
test_eval_fns::<F, C, _, D>(InterpolationGate::new(2))
|
||||
test_eval_fns::<F, C, _, D>(HighDegreeInterpolationGate::new(2))
|
||||
}
|
||||
|
||||
#[test]
|
||||
@ -370,7 +323,7 @@ mod tests {
|
||||
|
||||
/// Returns the local wires for an interpolation gate for given coeffs, points and eval point.
|
||||
fn get_wires(
|
||||
gate: &InterpolationGate<F, D>,
|
||||
gate: &HighDegreeInterpolationGate<F, D>,
|
||||
shift: F,
|
||||
coeffs: PolynomialCoeffs<FF>,
|
||||
eval_point: FF,
|
||||
@ -392,7 +345,7 @@ mod tests {
|
||||
let shift = F::rand();
|
||||
let coeffs = PolynomialCoeffs::new(vec![FF::rand(), FF::rand()]);
|
||||
let eval_point = FF::rand();
|
||||
let gate = InterpolationGate::<F, D>::new(1);
|
||||
let gate = HighDegreeInterpolationGate::<F, D>::new(1);
|
||||
let vars = EvaluationVars {
|
||||
local_constants: &[],
|
||||
local_wires: &get_wires(&gate, shift, coeffs, eval_point),
|
||||
|
||||
459
src/gates/low_degree_interpolation.rs
Normal file
459
src/gates/low_degree_interpolation.rs
Normal file
@ -0,0 +1,459 @@
|
||||
use std::marker::PhantomData;
|
||||
use std::ops::Range;
|
||||
|
||||
use crate::field::extension_field::algebra::PolynomialCoeffsAlgebra;
|
||||
use crate::field::extension_field::target::ExtensionTarget;
|
||||
use crate::field::extension_field::{Extendable, FieldExtension};
|
||||
use crate::field::field_types::{Field, RichField};
|
||||
use crate::field::interpolation::interpolant;
|
||||
use crate::gadgets::interpolation::InterpolationGate;
|
||||
use crate::gadgets::polynomial::PolynomialCoeffsExtAlgebraTarget;
|
||||
use crate::gates::gate::Gate;
|
||||
use crate::iop::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator};
|
||||
use crate::iop::target::Target;
|
||||
use crate::iop::wire::Wire;
|
||||
use crate::iop::witness::{PartitionWitness, Witness};
|
||||
use crate::plonk::circuit_builder::CircuitBuilder;
|
||||
use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase};
|
||||
use crate::polynomial::PolynomialCoeffs;
|
||||
|
||||
/// Interpolation gate with constraints of degree 2.
|
||||
/// `eval_unfiltered_recursively` uses more gates than `HighDegreeInterpolationGate`.
|
||||
#[derive(Copy, Clone, Debug)]
|
||||
pub(crate) struct LowDegreeInterpolationGate<F: RichField + Extendable<D>, const D: usize> {
|
||||
pub subgroup_bits: usize,
|
||||
_phantom: PhantomData<F>,
|
||||
}
|
||||
|
||||
impl<F: RichField + Extendable<D>, const D: usize> InterpolationGate<F, D>
|
||||
for LowDegreeInterpolationGate<F, D>
|
||||
{
|
||||
fn new(subgroup_bits: usize) -> Self {
|
||||
Self {
|
||||
subgroup_bits,
|
||||
_phantom: PhantomData,
|
||||
}
|
||||
}
|
||||
|
||||
fn num_points(&self) -> usize {
|
||||
1 << self.subgroup_bits
|
||||
}
|
||||
}
|
||||
|
||||
impl<F: RichField + Extendable<D>, const D: usize> LowDegreeInterpolationGate<F, D> {
|
||||
/// `powers_shift(i)` is the wire index of `wire_shift^i`.
|
||||
pub fn powers_shift(&self, i: usize) -> usize {
|
||||
debug_assert!(0 < i && i < self.num_points());
|
||||
if i == 1 {
|
||||
return self.wire_shift();
|
||||
}
|
||||
self.end_coeffs() + i - 2
|
||||
}
|
||||
|
||||
/// `powers_evalutation_point(i)` is the wire index of `evalutation_point^i`.
|
||||
pub fn powers_evaluation_point(&self, i: usize) -> Range<usize> {
|
||||
debug_assert!(0 < i && i < self.num_points());
|
||||
if i == 1 {
|
||||
return self.wires_evaluation_point();
|
||||
}
|
||||
let start = self.end_coeffs() + self.num_points() - 2 + (i - 2) * D;
|
||||
start..start + D
|
||||
}
|
||||
|
||||
/// End of wire indices, exclusive.
|
||||
fn end(&self) -> usize {
|
||||
self.powers_evaluation_point(self.num_points() - 1).end
|
||||
}
|
||||
|
||||
/// The domain of the points we're interpolating.
|
||||
fn coset(&self, shift: F) -> impl Iterator<Item = F> {
|
||||
let g = F::primitive_root_of_unity(self.subgroup_bits);
|
||||
let size = 1 << self.subgroup_bits;
|
||||
// Speed matters here, so we avoid `cyclic_subgroup_coset_known_order` which allocates.
|
||||
g.powers().take(size).map(move |x| x * shift)
|
||||
}
|
||||
}
|
||||
|
||||
impl<F: RichField + Extendable<D>, const D: usize> Gate<F, D> for LowDegreeInterpolationGate<F, D> {
|
||||
fn id(&self) -> String {
|
||||
format!("{:?}<D={}>", self, D)
|
||||
}
|
||||
|
||||
fn eval_unfiltered(&self, vars: EvaluationVars<F, D>) -> Vec<F::Extension> {
|
||||
let mut constraints = Vec::with_capacity(self.num_constraints());
|
||||
|
||||
let coeffs = (0..self.num_points())
|
||||
.map(|i| vars.get_local_ext_algebra(self.wires_coeff(i)))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let mut powers_shift = (1..self.num_points())
|
||||
.map(|i| vars.local_wires[self.powers_shift(i)])
|
||||
.collect::<Vec<_>>();
|
||||
let shift = powers_shift[0];
|
||||
for i in 1..self.num_points() - 1 {
|
||||
constraints.push(powers_shift[i - 1] * shift - powers_shift[i]);
|
||||
}
|
||||
powers_shift.insert(0, F::Extension::ONE);
|
||||
// `altered_coeffs[i] = c_i * shift^i`, where `c_i` is the original coefficient.
|
||||
// Then, `altered(w^i) = original(shift*w^i)`.
|
||||
let altered_coeffs = coeffs
|
||||
.iter()
|
||||
.zip(powers_shift)
|
||||
.map(|(&c, p)| c.scalar_mul(p))
|
||||
.collect::<Vec<_>>();
|
||||
let interpolant = PolynomialCoeffsAlgebra::new(coeffs);
|
||||
let altered_interpolant = PolynomialCoeffsAlgebra::new(altered_coeffs);
|
||||
|
||||
for (i, point) in F::Extension::two_adic_subgroup(self.subgroup_bits)
|
||||
.into_iter()
|
||||
.enumerate()
|
||||
{
|
||||
let value = vars.get_local_ext_algebra(self.wires_value(i));
|
||||
let computed_value = altered_interpolant.eval_base(point);
|
||||
constraints.extend(&(value - computed_value).to_basefield_array());
|
||||
}
|
||||
|
||||
let evaluation_point_powers = (1..self.num_points())
|
||||
.map(|i| vars.get_local_ext_algebra(self.powers_evaluation_point(i)))
|
||||
.collect::<Vec<_>>();
|
||||
let evaluation_point = evaluation_point_powers[0];
|
||||
for i in 1..self.num_points() - 1 {
|
||||
constraints.extend(
|
||||
(evaluation_point_powers[i - 1] * evaluation_point - evaluation_point_powers[i])
|
||||
.to_basefield_array(),
|
||||
);
|
||||
}
|
||||
let evaluation_value = vars.get_local_ext_algebra(self.wires_evaluation_value());
|
||||
let computed_evaluation_value = interpolant.eval_with_powers(&evaluation_point_powers);
|
||||
constraints.extend(&(evaluation_value - computed_evaluation_value).to_basefield_array());
|
||||
|
||||
constraints
|
||||
}
|
||||
|
||||
fn eval_unfiltered_base(&self, vars: EvaluationVarsBase<F>) -> Vec<F> {
|
||||
let mut constraints = Vec::with_capacity(self.num_constraints());
|
||||
|
||||
let coeffs = (0..self.num_points())
|
||||
.map(|i| vars.get_local_ext(self.wires_coeff(i)))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let mut powers_shift = (1..self.num_points())
|
||||
.map(|i| vars.local_wires[self.powers_shift(i)])
|
||||
.collect::<Vec<_>>();
|
||||
let shift = powers_shift[0];
|
||||
for i in 1..self.num_points() - 1 {
|
||||
constraints.push(powers_shift[i - 1] * shift - powers_shift[i]);
|
||||
}
|
||||
powers_shift.insert(0, F::ONE);
|
||||
// `altered_coeffs[i] = c_i * shift^i`, where `c_i` is the original coefficient.
|
||||
// Then, `altered(w^i) = original(shift*w^i)`.
|
||||
let altered_coeffs = coeffs
|
||||
.iter()
|
||||
.zip(powers_shift)
|
||||
.map(|(&c, p)| c.scalar_mul(p))
|
||||
.collect::<Vec<_>>();
|
||||
let interpolant = PolynomialCoeffs::new(coeffs);
|
||||
let altered_interpolant = PolynomialCoeffs::new(altered_coeffs);
|
||||
|
||||
for (i, point) in F::two_adic_subgroup(self.subgroup_bits)
|
||||
.into_iter()
|
||||
.enumerate()
|
||||
{
|
||||
let value = vars.get_local_ext(self.wires_value(i));
|
||||
let computed_value = altered_interpolant.eval_base(point);
|
||||
constraints.extend(&(value - computed_value).to_basefield_array());
|
||||
}
|
||||
|
||||
let evaluation_point_powers = (1..self.num_points())
|
||||
.map(|i| vars.get_local_ext(self.powers_evaluation_point(i)))
|
||||
.collect::<Vec<_>>();
|
||||
let evaluation_point = evaluation_point_powers[0];
|
||||
for i in 1..self.num_points() - 1 {
|
||||
constraints.extend(
|
||||
(evaluation_point_powers[i - 1] * evaluation_point - evaluation_point_powers[i])
|
||||
.to_basefield_array(),
|
||||
);
|
||||
}
|
||||
let evaluation_value = vars.get_local_ext(self.wires_evaluation_value());
|
||||
let computed_evaluation_value = interpolant.eval_with_powers(&evaluation_point_powers);
|
||||
constraints.extend(&(evaluation_value - computed_evaluation_value).to_basefield_array());
|
||||
|
||||
constraints
|
||||
}
|
||||
|
||||
fn eval_unfiltered_recursively(
|
||||
&self,
|
||||
builder: &mut CircuitBuilder<F, D>,
|
||||
vars: EvaluationTargets<D>,
|
||||
) -> Vec<ExtensionTarget<D>> {
|
||||
let mut constraints = Vec::with_capacity(self.num_constraints());
|
||||
|
||||
let coeffs = (0..self.num_points())
|
||||
.map(|i| vars.get_local_ext_algebra(self.wires_coeff(i)))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let mut powers_shift = (1..self.num_points())
|
||||
.map(|i| vars.local_wires[self.powers_shift(i)])
|
||||
.collect::<Vec<_>>();
|
||||
let shift = powers_shift[0];
|
||||
for i in 1..self.num_points() - 1 {
|
||||
constraints.push(builder.mul_sub_extension(
|
||||
powers_shift[i - 1],
|
||||
shift,
|
||||
powers_shift[i],
|
||||
));
|
||||
}
|
||||
powers_shift.insert(0, builder.one_extension());
|
||||
// `altered_coeffs[i] = c_i * shift^i`, where `c_i` is the original coefficient.
|
||||
// Then, `altered(w^i) = original(shift*w^i)`.
|
||||
let altered_coeffs = coeffs
|
||||
.iter()
|
||||
.zip(powers_shift)
|
||||
.map(|(&c, p)| builder.scalar_mul_ext_algebra(p, c))
|
||||
.collect::<Vec<_>>();
|
||||
let interpolant = PolynomialCoeffsExtAlgebraTarget(coeffs);
|
||||
let altered_interpolant = PolynomialCoeffsExtAlgebraTarget(altered_coeffs);
|
||||
|
||||
for (i, point) in F::Extension::two_adic_subgroup(self.subgroup_bits)
|
||||
.into_iter()
|
||||
.enumerate()
|
||||
{
|
||||
let value = vars.get_local_ext_algebra(self.wires_value(i));
|
||||
let point = builder.constant_extension(point);
|
||||
let computed_value = altered_interpolant.eval_scalar(builder, point);
|
||||
constraints.extend(
|
||||
&builder
|
||||
.sub_ext_algebra(value, computed_value)
|
||||
.to_ext_target_array(),
|
||||
);
|
||||
}
|
||||
|
||||
let evaluation_point_powers = (1..self.num_points())
|
||||
.map(|i| vars.get_local_ext_algebra(self.powers_evaluation_point(i)))
|
||||
.collect::<Vec<_>>();
|
||||
let evaluation_point = evaluation_point_powers[0];
|
||||
for i in 1..self.num_points() - 1 {
|
||||
let neg_one_ext = builder.neg_one_extension();
|
||||
let neg_new_power =
|
||||
builder.scalar_mul_ext_algebra(neg_one_ext, evaluation_point_powers[i]);
|
||||
let constraint = builder.mul_add_ext_algebra(
|
||||
evaluation_point,
|
||||
evaluation_point_powers[i - 1],
|
||||
neg_new_power,
|
||||
);
|
||||
constraints.extend(constraint.to_ext_target_array());
|
||||
}
|
||||
let evaluation_value = vars.get_local_ext_algebra(self.wires_evaluation_value());
|
||||
let computed_evaluation_value =
|
||||
interpolant.eval_with_powers(builder, &evaluation_point_powers);
|
||||
// let evaluation_point = vars.get_local_ext_algebra(self.wires_evaluation_point());
|
||||
// let evaluation_value = vars.get_local_ext_algebra(self.wires_evaluation_value());
|
||||
// let computed_evaluation_value = interpolant.eval(builder, evaluation_point);
|
||||
constraints.extend(
|
||||
&builder
|
||||
.sub_ext_algebra(evaluation_value, computed_evaluation_value)
|
||||
.to_ext_target_array(),
|
||||
);
|
||||
|
||||
constraints
|
||||
}
|
||||
|
||||
fn generators(
|
||||
&self,
|
||||
gate_index: usize,
|
||||
_local_constants: &[F],
|
||||
) -> Vec<Box<dyn WitnessGenerator<F>>> {
|
||||
let gen = InterpolationGenerator::<F, D> {
|
||||
gate_index,
|
||||
gate: *self,
|
||||
_phantom: PhantomData,
|
||||
};
|
||||
vec![Box::new(gen.adapter())]
|
||||
}
|
||||
|
||||
fn num_wires(&self) -> usize {
|
||||
self.end()
|
||||
}
|
||||
|
||||
fn num_constants(&self) -> usize {
|
||||
0
|
||||
}
|
||||
|
||||
fn degree(&self) -> usize {
|
||||
2
|
||||
}
|
||||
|
||||
fn num_constraints(&self) -> usize {
|
||||
// `num_points * D` constraints to check for consistency between the coefficients and the
|
||||
// point-value pairs, plus `D` constraints for the evaluation value, plus `(D+1)*(num_points-2)`
|
||||
// to check power constraints for evaluation point and shift.
|
||||
self.num_points() * D + D + (D + 1) * (self.num_points() - 2)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct InterpolationGenerator<F: RichField + Extendable<D>, const D: usize> {
|
||||
gate_index: usize,
|
||||
gate: LowDegreeInterpolationGate<F, D>,
|
||||
_phantom: PhantomData<F>,
|
||||
}
|
||||
|
||||
impl<F: RichField + Extendable<D>, const D: usize> SimpleGenerator<F>
|
||||
for InterpolationGenerator<F, D>
|
||||
{
|
||||
fn dependencies(&self) -> Vec<Target> {
|
||||
let local_target = |input| {
|
||||
Target::Wire(Wire {
|
||||
gate: self.gate_index,
|
||||
input,
|
||||
})
|
||||
};
|
||||
|
||||
let local_targets = |inputs: Range<usize>| inputs.map(local_target);
|
||||
|
||||
let num_points = self.gate.num_points();
|
||||
let mut deps = Vec::with_capacity(1 + D + num_points * D);
|
||||
|
||||
deps.push(local_target(self.gate.wire_shift()));
|
||||
deps.extend(local_targets(self.gate.wires_evaluation_point()));
|
||||
for i in 0..num_points {
|
||||
deps.extend(local_targets(self.gate.wires_value(i)));
|
||||
}
|
||||
deps
|
||||
}
|
||||
|
||||
fn run_once(&self, witness: &PartitionWitness<F>, out_buffer: &mut GeneratedValues<F>) {
|
||||
let local_wire = |input| Wire {
|
||||
gate: self.gate_index,
|
||||
input,
|
||||
};
|
||||
|
||||
let get_local_wire = |input| witness.get_wire(local_wire(input));
|
||||
|
||||
let get_local_ext = |wire_range: Range<usize>| {
|
||||
debug_assert_eq!(wire_range.len(), D);
|
||||
let values = wire_range.map(get_local_wire).collect::<Vec<_>>();
|
||||
let arr = values.try_into().unwrap();
|
||||
F::Extension::from_basefield_array(arr)
|
||||
};
|
||||
|
||||
let wire_shift = get_local_wire(self.gate.wire_shift());
|
||||
|
||||
for (i, power) in wire_shift
|
||||
.powers()
|
||||
.take(self.gate.num_points())
|
||||
.enumerate()
|
||||
.skip(2)
|
||||
{
|
||||
out_buffer.set_wire(local_wire(self.gate.powers_shift(i)), power);
|
||||
}
|
||||
|
||||
// Compute the interpolant.
|
||||
let points = self.gate.coset(wire_shift);
|
||||
let points = points
|
||||
.into_iter()
|
||||
.enumerate()
|
||||
.map(|(i, point)| (point.into(), get_local_ext(self.gate.wires_value(i))))
|
||||
.collect::<Vec<_>>();
|
||||
let interpolant = interpolant(&points);
|
||||
|
||||
for (i, &coeff) in interpolant.coeffs.iter().enumerate() {
|
||||
let wires = self.gate.wires_coeff(i).map(local_wire);
|
||||
out_buffer.set_ext_wires(wires, coeff);
|
||||
}
|
||||
|
||||
let evaluation_point = get_local_ext(self.gate.wires_evaluation_point());
|
||||
for (i, power) in evaluation_point
|
||||
.powers()
|
||||
.take(self.gate.num_points())
|
||||
.enumerate()
|
||||
.skip(2)
|
||||
{
|
||||
out_buffer.set_extension_target(
|
||||
ExtensionTarget::from_range(self.gate_index, self.gate.powers_evaluation_point(i)),
|
||||
power,
|
||||
);
|
||||
}
|
||||
let evaluation_value = interpolant.eval(evaluation_point);
|
||||
let evaluation_value_wires = self.gate.wires_evaluation_value().map(local_wire);
|
||||
out_buffer.set_ext_wires(evaluation_value_wires, evaluation_value);
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use anyhow::Result;
|
||||
|
||||
use crate::field::extension_field::quadratic::QuadraticExtension;
|
||||
use crate::field::field_types::Field;
|
||||
use crate::field::goldilocks_field::GoldilocksField;
|
||||
use crate::gadgets::interpolation::InterpolationGate;
|
||||
use crate::gates::gate::Gate;
|
||||
use crate::gates::gate_testing::{test_eval_fns, test_low_degree};
|
||||
use crate::gates::low_degree_interpolation::LowDegreeInterpolationGate;
|
||||
use crate::hash::hash_types::HashOut;
|
||||
use crate::plonk::vars::EvaluationVars;
|
||||
use crate::polynomial::PolynomialCoeffs;
|
||||
|
||||
#[test]
|
||||
fn low_degree() {
|
||||
test_low_degree::<GoldilocksField, _, 4>(LowDegreeInterpolationGate::new(4));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn eval_fns() -> Result<()> {
|
||||
test_eval_fns::<GoldilocksField, _, 4>(LowDegreeInterpolationGate::new(4))
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_gate_constraint() {
|
||||
type F = GoldilocksField;
|
||||
type FF = QuadraticExtension<GoldilocksField>;
|
||||
const D: usize = 2;
|
||||
|
||||
/// Returns the local wires for an interpolation gate for given coeffs, points and eval point.
|
||||
fn get_wires(
|
||||
gate: &LowDegreeInterpolationGate<F, D>,
|
||||
shift: F,
|
||||
coeffs: PolynomialCoeffs<FF>,
|
||||
eval_point: FF,
|
||||
) -> Vec<FF> {
|
||||
let points = gate.coset(shift);
|
||||
let mut v = vec![shift];
|
||||
for x in points {
|
||||
v.extend(coeffs.eval(x.into()).0);
|
||||
}
|
||||
v.extend(eval_point.0);
|
||||
v.extend(coeffs.eval(eval_point).0);
|
||||
for i in 0..coeffs.len() {
|
||||
v.extend(coeffs.coeffs[i].0);
|
||||
}
|
||||
v.extend(shift.powers().skip(2).take(gate.num_points() - 2));
|
||||
v.extend(
|
||||
eval_point
|
||||
.powers()
|
||||
.skip(2)
|
||||
.take(gate.num_points() - 2)
|
||||
.flat_map(|ff| ff.0),
|
||||
);
|
||||
v.iter().map(|&x| x.into()).collect::<Vec<_>>()
|
||||
}
|
||||
|
||||
// Get a working row for LowDegreeInterpolationGate.
|
||||
let subgroup_bits = 4;
|
||||
let shift = F::rand();
|
||||
let coeffs = PolynomialCoeffs::new(FF::rand_vec(1 << subgroup_bits));
|
||||
let eval_point = FF::rand();
|
||||
let gate = LowDegreeInterpolationGate::<F, D>::new(subgroup_bits);
|
||||
let vars = EvaluationVars {
|
||||
local_constants: &[],
|
||||
local_wires: &get_wires(&gate, shift, coeffs, eval_point),
|
||||
public_inputs_hash: &HashOut::rand(),
|
||||
};
|
||||
|
||||
assert!(
|
||||
gate.eval_unfiltered(vars).iter().all(|x| x.is_zero()),
|
||||
"Gate constraints are not satisfied."
|
||||
);
|
||||
}
|
||||
}
|
||||
@ -1,8 +1,10 @@
|
||||
// Gates have `new` methods that return `GateRef`s.
|
||||
#![allow(clippy::new_ret_no_self)]
|
||||
|
||||
pub mod arithmetic;
|
||||
pub mod arithmetic_base;
|
||||
pub mod arithmetic_extension;
|
||||
pub mod arithmetic_u32;
|
||||
pub mod assert_le;
|
||||
pub mod base_sum;
|
||||
pub mod comparison;
|
||||
pub mod constant;
|
||||
@ -12,12 +14,16 @@ pub mod gate_tree;
|
||||
pub mod gmimc;
|
||||
pub mod insertion;
|
||||
pub mod interpolation;
|
||||
pub mod low_degree_interpolation;
|
||||
pub mod multiplication_extension;
|
||||
pub mod noop;
|
||||
pub mod poseidon;
|
||||
pub(crate) mod poseidon_mds;
|
||||
pub(crate) mod public_input;
|
||||
pub mod random_access;
|
||||
pub mod reducing;
|
||||
pub mod reducing_extension;
|
||||
pub mod subtraction_u32;
|
||||
pub mod switch;
|
||||
|
||||
#[cfg(test)]
|
||||
|
||||
204
src/gates/multiplication_extension.rs
Normal file
204
src/gates/multiplication_extension.rs
Normal file
@ -0,0 +1,204 @@
|
||||
use std::ops::Range;
|
||||
|
||||
use crate::field::extension_field::target::ExtensionTarget;
|
||||
use crate::field::extension_field::Extendable;
|
||||
use crate::field::extension_field::FieldExtension;
|
||||
use crate::field::field_types::RichField;
|
||||
use crate::gates::gate::Gate;
|
||||
use crate::iop::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator};
|
||||
use crate::iop::target::Target;
|
||||
use crate::iop::witness::{PartitionWitness, Witness};
|
||||
use crate::plonk::circuit_builder::CircuitBuilder;
|
||||
use crate::plonk::circuit_data::CircuitConfig;
|
||||
use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase};
|
||||
|
||||
/// A gate which can perform a weighted multiplication, i.e. `result = c0 x y`. If the config
|
||||
/// supports enough routed wires, it can support several such operations in one gate.
|
||||
#[derive(Debug)]
|
||||
pub struct MulExtensionGate<const D: usize> {
|
||||
/// Number of multiplications performed by the gate.
|
||||
pub num_ops: usize,
|
||||
}
|
||||
|
||||
impl<const D: usize> MulExtensionGate<D> {
|
||||
pub fn new_from_config(config: &CircuitConfig) -> Self {
|
||||
Self {
|
||||
num_ops: Self::num_ops(config),
|
||||
}
|
||||
}
|
||||
|
||||
/// Determine the maximum number of operations that can fit in one gate for the given config.
|
||||
pub(crate) fn num_ops(config: &CircuitConfig) -> usize {
|
||||
let wires_per_op = 3 * D;
|
||||
config.num_routed_wires / wires_per_op
|
||||
}
|
||||
|
||||
pub fn wires_ith_multiplicand_0(i: usize) -> Range<usize> {
|
||||
3 * D * i..3 * D * i + D
|
||||
}
|
||||
pub fn wires_ith_multiplicand_1(i: usize) -> Range<usize> {
|
||||
3 * D * i + D..3 * D * i + 2 * D
|
||||
}
|
||||
pub fn wires_ith_output(i: usize) -> Range<usize> {
|
||||
3 * D * i + 2 * D..3 * D * i + 3 * D
|
||||
}
|
||||
}
|
||||
|
||||
impl<F: RichField + Extendable<D>, const D: usize> Gate<F, D> for MulExtensionGate<D> {
|
||||
fn id(&self) -> String {
|
||||
format!("{:?}", self)
|
||||
}
|
||||
|
||||
fn eval_unfiltered(&self, vars: EvaluationVars<F, D>) -> Vec<F::Extension> {
|
||||
let const_0 = vars.local_constants[0];
|
||||
|
||||
let mut constraints = Vec::new();
|
||||
for i in 0..self.num_ops {
|
||||
let multiplicand_0 = vars.get_local_ext_algebra(Self::wires_ith_multiplicand_0(i));
|
||||
let multiplicand_1 = vars.get_local_ext_algebra(Self::wires_ith_multiplicand_1(i));
|
||||
let output = vars.get_local_ext_algebra(Self::wires_ith_output(i));
|
||||
let computed_output = (multiplicand_0 * multiplicand_1).scalar_mul(const_0);
|
||||
|
||||
constraints.extend((output - computed_output).to_basefield_array());
|
||||
}
|
||||
|
||||
constraints
|
||||
}
|
||||
|
||||
fn eval_unfiltered_base(&self, vars: EvaluationVarsBase<F>) -> Vec<F> {
|
||||
let const_0 = vars.local_constants[0];
|
||||
|
||||
let mut constraints = Vec::new();
|
||||
for i in 0..self.num_ops {
|
||||
let multiplicand_0 = vars.get_local_ext(Self::wires_ith_multiplicand_0(i));
|
||||
let multiplicand_1 = vars.get_local_ext(Self::wires_ith_multiplicand_1(i));
|
||||
let output = vars.get_local_ext(Self::wires_ith_output(i));
|
||||
let computed_output = (multiplicand_0 * multiplicand_1).scalar_mul(const_0);
|
||||
|
||||
constraints.extend((output - computed_output).to_basefield_array());
|
||||
}
|
||||
|
||||
constraints
|
||||
}
|
||||
|
||||
fn eval_unfiltered_recursively(
|
||||
&self,
|
||||
builder: &mut CircuitBuilder<F, D>,
|
||||
vars: EvaluationTargets<D>,
|
||||
) -> Vec<ExtensionTarget<D>> {
|
||||
let const_0 = vars.local_constants[0];
|
||||
|
||||
let mut constraints = Vec::new();
|
||||
for i in 0..self.num_ops {
|
||||
let multiplicand_0 = vars.get_local_ext_algebra(Self::wires_ith_multiplicand_0(i));
|
||||
let multiplicand_1 = vars.get_local_ext_algebra(Self::wires_ith_multiplicand_1(i));
|
||||
let output = vars.get_local_ext_algebra(Self::wires_ith_output(i));
|
||||
let computed_output = {
|
||||
let mul = builder.mul_ext_algebra(multiplicand_0, multiplicand_1);
|
||||
builder.scalar_mul_ext_algebra(const_0, mul)
|
||||
};
|
||||
|
||||
let diff = builder.sub_ext_algebra(output, computed_output);
|
||||
constraints.extend(diff.to_ext_target_array());
|
||||
}
|
||||
|
||||
constraints
|
||||
}
|
||||
|
||||
fn generators(
|
||||
&self,
|
||||
gate_index: usize,
|
||||
local_constants: &[F],
|
||||
) -> Vec<Box<dyn WitnessGenerator<F>>> {
|
||||
(0..self.num_ops)
|
||||
.map(|i| {
|
||||
let g: Box<dyn WitnessGenerator<F>> = Box::new(
|
||||
MulExtensionGenerator {
|
||||
gate_index,
|
||||
const_0: local_constants[0],
|
||||
i,
|
||||
}
|
||||
.adapter(),
|
||||
);
|
||||
g
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
}
|
||||
|
||||
fn num_wires(&self) -> usize {
|
||||
self.num_ops * 3 * D
|
||||
}
|
||||
|
||||
fn num_constants(&self) -> usize {
|
||||
1
|
||||
}
|
||||
|
||||
fn degree(&self) -> usize {
|
||||
3
|
||||
}
|
||||
|
||||
fn num_constraints(&self) -> usize {
|
||||
self.num_ops * D
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
struct MulExtensionGenerator<F: RichField + Extendable<D>, const D: usize> {
|
||||
gate_index: usize,
|
||||
const_0: F,
|
||||
i: usize,
|
||||
}
|
||||
|
||||
impl<F: RichField + Extendable<D>, const D: usize> SimpleGenerator<F>
|
||||
for MulExtensionGenerator<F, D>
|
||||
{
|
||||
fn dependencies(&self) -> Vec<Target> {
|
||||
MulExtensionGate::<D>::wires_ith_multiplicand_0(self.i)
|
||||
.chain(MulExtensionGate::<D>::wires_ith_multiplicand_1(self.i))
|
||||
.map(|i| Target::wire(self.gate_index, i))
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn run_once(&self, witness: &PartitionWitness<F>, out_buffer: &mut GeneratedValues<F>) {
|
||||
let extract_extension = |range: Range<usize>| -> F::Extension {
|
||||
let t = ExtensionTarget::from_range(self.gate_index, range);
|
||||
witness.get_extension_target(t)
|
||||
};
|
||||
|
||||
let multiplicand_0 =
|
||||
extract_extension(MulExtensionGate::<D>::wires_ith_multiplicand_0(self.i));
|
||||
let multiplicand_1 =
|
||||
extract_extension(MulExtensionGate::<D>::wires_ith_multiplicand_1(self.i));
|
||||
|
||||
let output_target = ExtensionTarget::from_range(
|
||||
self.gate_index,
|
||||
MulExtensionGate::<D>::wires_ith_output(self.i),
|
||||
);
|
||||
|
||||
let computed_output = (multiplicand_0 * multiplicand_1).scalar_mul(self.const_0);
|
||||
|
||||
out_buffer.set_extension_target(output_target, computed_output)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use anyhow::Result;
|
||||
|
||||
use crate::field::goldilocks_field::GoldilocksField;
|
||||
use crate::gates::gate_testing::{test_eval_fns, test_low_degree};
|
||||
use crate::gates::multiplication_extension::MulExtensionGate;
|
||||
use crate::plonk::circuit_data::CircuitConfig;
|
||||
|
||||
#[test]
|
||||
fn low_degree() {
|
||||
let gate = MulExtensionGate::new_from_config(&CircuitConfig::standard_recursion_config());
|
||||
test_low_degree::<GoldilocksField, _, 4>(gate);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn eval_fns() -> Result<()> {
|
||||
let gate = MulExtensionGate::new_from_config(&CircuitConfig::standard_recursion_config());
|
||||
test_eval_fns::<GoldilocksField, _, 4>(gate)
|
||||
}
|
||||
}
|
||||
@ -1,4 +1,3 @@
|
||||
use std::convert::TryInto;
|
||||
use std::marker::PhantomData;
|
||||
|
||||
use crate::field::extension_field::target::ExtensionTarget;
|
||||
@ -47,44 +46,59 @@ impl<F: Extendable<D>, const D: usize> PoseidonGate<F, D> {
|
||||
/// is useful for ordering hashes in Merkle proofs. Otherwise, this should be set to 0.
|
||||
pub const WIRE_SWAP: usize = 2 * SPONGE_WIDTH;
|
||||
|
||||
const START_DELTA: usize = 2 * WIDTH + 1;
|
||||
|
||||
/// A wire which stores `swap * (input[i + 4] - input[i])`; used to compute the swapped inputs.
|
||||
fn wire_delta(i: usize) -> usize {
|
||||
assert!(i < 4);
|
||||
Self::START_DELTA + i
|
||||
}
|
||||
|
||||
const START_FULL_0: usize = Self::START_DELTA + 4;
|
||||
|
||||
/// A wire which stores the input of the `i`-th S-box of the `round`-th round of the first set
|
||||
/// of full rounds.
|
||||
fn wire_full_sbox_0(round: usize, i: usize) -> usize {
|
||||
debug_assert!(
|
||||
round != 0,
|
||||
"First round S-box inputs are not stored as wires"
|
||||
);
|
||||
debug_assert!(round < poseidon::HALF_N_FULL_ROUNDS);
|
||||
debug_assert!(i < SPONGE_WIDTH);
|
||||
2 * SPONGE_WIDTH + 1 + SPONGE_WIDTH * round + i
|
||||
debug_assert!(i < WIDTH);
|
||||
Self::START_FULL_0 + WIDTH * (round - 1) + i
|
||||
}
|
||||
|
||||
const START_PARTIAL: usize = Self::START_FULL_0 + WIDTH * (poseidon::HALF_N_FULL_ROUNDS - 1);
|
||||
|
||||
/// A wire which stores the input of the S-box of the `round`-th round of the partial rounds.
|
||||
fn wire_partial_sbox(round: usize) -> usize {
|
||||
debug_assert!(round < poseidon::N_PARTIAL_ROUNDS);
|
||||
2 * SPONGE_WIDTH + 1 + SPONGE_WIDTH * poseidon::HALF_N_FULL_ROUNDS + round
|
||||
Self::START_PARTIAL + round
|
||||
}
|
||||
|
||||
const START_FULL_1: usize = Self::START_PARTIAL + poseidon::N_PARTIAL_ROUNDS;
|
||||
|
||||
/// A wire which stores the input of the `i`-th S-box of the `round`-th round of the second set
|
||||
/// of full rounds.
|
||||
fn wire_full_sbox_1(round: usize, i: usize) -> usize {
|
||||
debug_assert!(round < poseidon::HALF_N_FULL_ROUNDS);
|
||||
debug_assert!(i < SPONGE_WIDTH);
|
||||
2 * SPONGE_WIDTH
|
||||
+ 1
|
||||
+ SPONGE_WIDTH * (poseidon::HALF_N_FULL_ROUNDS + round)
|
||||
+ poseidon::N_PARTIAL_ROUNDS
|
||||
+ i
|
||||
debug_assert!(i < WIDTH);
|
||||
Self::START_FULL_1 + WIDTH * round + i
|
||||
}
|
||||
|
||||
/// End of wire indices, exclusive.
|
||||
fn end() -> usize {
|
||||
2 * SPONGE_WIDTH
|
||||
+ 1
|
||||
+ SPONGE_WIDTH * poseidon::N_FULL_ROUNDS_TOTAL
|
||||
+ poseidon::N_PARTIAL_ROUNDS
|
||||
Self::START_FULL_1 + WIDTH * poseidon::HALF_N_FULL_ROUNDS
|
||||
}
|
||||
}
|
||||
|
||||
impl<F: Extendable<D>, const D: usize> Gate<F, D> for PoseidonGate<F, D> {
|
||||
impl<F: RichField + Extendable<D> + Poseidon<WIDTH>, const D: usize, const WIDTH: usize> Gate<F, D>
|
||||
for PoseidonGate<F, D, WIDTH>
|
||||
where
|
||||
[(); WIDTH - 1]:,
|
||||
{
|
||||
fn id(&self) -> String {
|
||||
format!("{:?}<SPONGE_WIDTH={}>", self, SPONGE_WIDTH)
|
||||
format!("{:?}<WIDTH={}>", self, WIDTH)
|
||||
}
|
||||
|
||||
fn eval_unfiltered(&self, vars: EvaluationVars<F, D>) -> Vec<F::Extension> {
|
||||
@ -94,69 +108,79 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for PoseidonGate<F, D> {
|
||||
let swap = vars.local_wires[Self::WIRE_SWAP];
|
||||
constraints.push(swap * (swap - F::Extension::ONE));
|
||||
|
||||
let mut state = Vec::with_capacity(SPONGE_WIDTH);
|
||||
// Assert that each delta wire is set properly: `delta_i = swap * (rhs - lhs)`.
|
||||
for i in 0..4 {
|
||||
let a = vars.local_wires[i];
|
||||
let b = vars.local_wires[i + 4];
|
||||
state.push(a + swap * (b - a));
|
||||
}
|
||||
for i in 0..4 {
|
||||
let a = vars.local_wires[i + 4];
|
||||
let b = vars.local_wires[i];
|
||||
state.push(a + swap * (b - a));
|
||||
}
|
||||
for i in 8..SPONGE_WIDTH {
|
||||
state.push(vars.local_wires[i]);
|
||||
let input_lhs = vars.local_wires[Self::wire_input(i)];
|
||||
let input_rhs = vars.local_wires[Self::wire_input(i + 4)];
|
||||
let delta_i = vars.local_wires[Self::wire_delta(i)];
|
||||
constraints.push(swap * (input_rhs - input_lhs) - delta_i);
|
||||
}
|
||||
|
||||
// Compute the possibly-swapped input layer.
|
||||
let mut state = [F::Extension::ZERO; WIDTH];
|
||||
for i in 0..4 {
|
||||
let delta_i = vars.local_wires[Self::wire_delta(i)];
|
||||
let input_lhs = Self::wire_input(i);
|
||||
let input_rhs = Self::wire_input(i + 4);
|
||||
state[i] = vars.local_wires[input_lhs] + delta_i;
|
||||
state[i + 4] = vars.local_wires[input_rhs] - delta_i;
|
||||
}
|
||||
for i in 8..WIDTH {
|
||||
state[i] = vars.local_wires[Self::wire_input(i)];
|
||||
}
|
||||
|
||||
let mut state: [F::Extension; SPONGE_WIDTH] = state.try_into().unwrap();
|
||||
let mut round_ctr = 0;
|
||||
|
||||
// First set of full rounds.
|
||||
for r in 0..poseidon::HALF_N_FULL_ROUNDS {
|
||||
<F as Poseidon>::constant_layer_field(&mut state, round_ctr);
|
||||
for i in 0..SPONGE_WIDTH {
|
||||
let sbox_in = vars.local_wires[Self::wire_full_sbox_0(r, i)];
|
||||
constraints.push(state[i] - sbox_in);
|
||||
state[i] = sbox_in;
|
||||
<F as Poseidon<WIDTH>>::constant_layer_field(&mut state, round_ctr);
|
||||
if r != 0 {
|
||||
for i in 0..WIDTH {
|
||||
let sbox_in = vars.local_wires[Self::wire_full_sbox_0(r, i)];
|
||||
constraints.push(state[i] - sbox_in);
|
||||
state[i] = sbox_in;
|
||||
}
|
||||
}
|
||||
<F as Poseidon>::sbox_layer_field(&mut state);
|
||||
state = <F as Poseidon>::mds_layer_field(&state);
|
||||
<F as Poseidon<WIDTH>>::sbox_layer_field(&mut state);
|
||||
state = <F as Poseidon<WIDTH>>::mds_layer_field(&state);
|
||||
round_ctr += 1;
|
||||
}
|
||||
|
||||
// Partial rounds.
|
||||
<F as Poseidon>::partial_first_constant_layer(&mut state);
|
||||
state = <F as Poseidon>::mds_partial_layer_init(&mut state);
|
||||
<F as Poseidon<WIDTH>>::partial_first_constant_layer(&mut state);
|
||||
state = <F as Poseidon<WIDTH>>::mds_partial_layer_init(&state);
|
||||
for r in 0..(poseidon::N_PARTIAL_ROUNDS - 1) {
|
||||
let sbox_in = vars.local_wires[Self::wire_partial_sbox(r)];
|
||||
constraints.push(state[0] - sbox_in);
|
||||
state[0] = <F as Poseidon>::sbox_monomial(sbox_in);
|
||||
state[0] +=
|
||||
F::Extension::from_canonical_u64(<F as Poseidon>::FAST_PARTIAL_ROUND_CONSTANTS[r]);
|
||||
state = <F as Poseidon>::mds_partial_layer_fast_field(&state, r);
|
||||
state[0] = <F as Poseidon<WIDTH>>::sbox_monomial(sbox_in);
|
||||
state[0] += F::Extension::from_canonical_u64(
|
||||
<F as Poseidon<WIDTH>>::FAST_PARTIAL_ROUND_CONSTANTS[r],
|
||||
);
|
||||
state = <F as Poseidon<WIDTH>>::mds_partial_layer_fast_field(&state, r);
|
||||
}
|
||||
let sbox_in = vars.local_wires[Self::wire_partial_sbox(poseidon::N_PARTIAL_ROUNDS - 1)];
|
||||
constraints.push(state[0] - sbox_in);
|
||||
state[0] = <F as Poseidon>::sbox_monomial(sbox_in);
|
||||
state =
|
||||
<F as Poseidon>::mds_partial_layer_fast_field(&state, poseidon::N_PARTIAL_ROUNDS - 1);
|
||||
state[0] = <F as Poseidon<WIDTH>>::sbox_monomial(sbox_in);
|
||||
state = <F as Poseidon<WIDTH>>::mds_partial_layer_fast_field(
|
||||
&state,
|
||||
poseidon::N_PARTIAL_ROUNDS - 1,
|
||||
);
|
||||
round_ctr += poseidon::N_PARTIAL_ROUNDS;
|
||||
|
||||
// Second set of full rounds.
|
||||
for r in 0..poseidon::HALF_N_FULL_ROUNDS {
|
||||
<F as Poseidon>::constant_layer_field(&mut state, round_ctr);
|
||||
for i in 0..SPONGE_WIDTH {
|
||||
<F as Poseidon<WIDTH>>::constant_layer_field(&mut state, round_ctr);
|
||||
for i in 0..WIDTH {
|
||||
let sbox_in = vars.local_wires[Self::wire_full_sbox_1(r, i)];
|
||||
constraints.push(state[i] - sbox_in);
|
||||
state[i] = sbox_in;
|
||||
}
|
||||
<F as Poseidon>::sbox_layer_field(&mut state);
|
||||
state = <F as Poseidon>::mds_layer_field(&state);
|
||||
<F as Poseidon<WIDTH>>::sbox_layer_field(&mut state);
|
||||
state = <F as Poseidon<WIDTH>>::mds_layer_field(&state);
|
||||
round_ctr += 1;
|
||||
}
|
||||
|
||||
for i in 0..SPONGE_WIDTH {
|
||||
for i in 0..WIDTH {
|
||||
constraints.push(state[i] - vars.local_wires[Self::wire_output(i)]);
|
||||
}
|
||||
|
||||
@ -170,67 +194,76 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for PoseidonGate<F, D> {
|
||||
let swap = vars.local_wires[Self::WIRE_SWAP];
|
||||
constraints.push(swap * swap.sub_one());
|
||||
|
||||
let mut state = Vec::with_capacity(SPONGE_WIDTH);
|
||||
// Assert that each delta wire is set properly: `delta_i = swap * (rhs - lhs)`.
|
||||
for i in 0..4 {
|
||||
let a = vars.local_wires[i];
|
||||
let b = vars.local_wires[i + 4];
|
||||
state.push(a + swap * (b - a));
|
||||
}
|
||||
for i in 0..4 {
|
||||
let a = vars.local_wires[i + 4];
|
||||
let b = vars.local_wires[i];
|
||||
state.push(a + swap * (b - a));
|
||||
}
|
||||
for i in 8..SPONGE_WIDTH {
|
||||
state.push(vars.local_wires[i]);
|
||||
let input_lhs = vars.local_wires[Self::wire_input(i)];
|
||||
let input_rhs = vars.local_wires[Self::wire_input(i + 4)];
|
||||
let delta_i = vars.local_wires[Self::wire_delta(i)];
|
||||
constraints.push(swap * (input_rhs - input_lhs) - delta_i);
|
||||
}
|
||||
|
||||
// Compute the possibly-swapped input layer.
|
||||
let mut state = [F::ZERO; WIDTH];
|
||||
for i in 0..4 {
|
||||
let delta_i = vars.local_wires[Self::wire_delta(i)];
|
||||
let input_lhs = Self::wire_input(i);
|
||||
let input_rhs = Self::wire_input(i + 4);
|
||||
state[i] = vars.local_wires[input_lhs] + delta_i;
|
||||
state[i + 4] = vars.local_wires[input_rhs] - delta_i;
|
||||
}
|
||||
for i in 8..WIDTH {
|
||||
state[i] = vars.local_wires[Self::wire_input(i)];
|
||||
}
|
||||
|
||||
let mut state: [F; SPONGE_WIDTH] = state.try_into().unwrap();
|
||||
let mut round_ctr = 0;
|
||||
|
||||
// First set of full rounds.
|
||||
for r in 0..poseidon::HALF_N_FULL_ROUNDS {
|
||||
<F as Poseidon>::constant_layer(&mut state, round_ctr);
|
||||
for i in 0..SPONGE_WIDTH {
|
||||
let sbox_in = vars.local_wires[Self::wire_full_sbox_0(r, i)];
|
||||
constraints.push(state[i] - sbox_in);
|
||||
state[i] = sbox_in;
|
||||
<F as Poseidon<WIDTH>>::constant_layer(&mut state, round_ctr);
|
||||
if r != 0 {
|
||||
for i in 0..WIDTH {
|
||||
let sbox_in = vars.local_wires[Self::wire_full_sbox_0(r, i)];
|
||||
constraints.push(state[i] - sbox_in);
|
||||
state[i] = sbox_in;
|
||||
}
|
||||
}
|
||||
<F as Poseidon>::sbox_layer(&mut state);
|
||||
state = <F as Poseidon>::mds_layer(&state);
|
||||
<F as Poseidon<WIDTH>>::sbox_layer(&mut state);
|
||||
state = <F as Poseidon<WIDTH>>::mds_layer(&state);
|
||||
round_ctr += 1;
|
||||
}
|
||||
|
||||
// Partial rounds.
|
||||
<F as Poseidon>::partial_first_constant_layer(&mut state);
|
||||
state = <F as Poseidon>::mds_partial_layer_init(&mut state);
|
||||
<F as Poseidon<WIDTH>>::partial_first_constant_layer(&mut state);
|
||||
state = <F as Poseidon<WIDTH>>::mds_partial_layer_init(&state);
|
||||
for r in 0..(poseidon::N_PARTIAL_ROUNDS - 1) {
|
||||
let sbox_in = vars.local_wires[Self::wire_partial_sbox(r)];
|
||||
constraints.push(state[0] - sbox_in);
|
||||
state[0] = <F as Poseidon>::sbox_monomial(sbox_in);
|
||||
state[0] += F::from_canonical_u64(<F as Poseidon>::FAST_PARTIAL_ROUND_CONSTANTS[r]);
|
||||
state = <F as Poseidon>::mds_partial_layer_fast(&state, r);
|
||||
state[0] = <F as Poseidon<WIDTH>>::sbox_monomial(sbox_in);
|
||||
state[0] +=
|
||||
F::from_canonical_u64(<F as Poseidon<WIDTH>>::FAST_PARTIAL_ROUND_CONSTANTS[r]);
|
||||
state = <F as Poseidon<WIDTH>>::mds_partial_layer_fast(&state, r);
|
||||
}
|
||||
let sbox_in = vars.local_wires[Self::wire_partial_sbox(poseidon::N_PARTIAL_ROUNDS - 1)];
|
||||
constraints.push(state[0] - sbox_in);
|
||||
state[0] = <F as Poseidon>::sbox_monomial(sbox_in);
|
||||
state = <F as Poseidon>::mds_partial_layer_fast(&state, poseidon::N_PARTIAL_ROUNDS - 1);
|
||||
state[0] = <F as Poseidon<WIDTH>>::sbox_monomial(sbox_in);
|
||||
state =
|
||||
<F as Poseidon<WIDTH>>::mds_partial_layer_fast(&state, poseidon::N_PARTIAL_ROUNDS - 1);
|
||||
round_ctr += poseidon::N_PARTIAL_ROUNDS;
|
||||
|
||||
// Second set of full rounds.
|
||||
for r in 0..poseidon::HALF_N_FULL_ROUNDS {
|
||||
<F as Poseidon>::constant_layer(&mut state, round_ctr);
|
||||
for i in 0..SPONGE_WIDTH {
|
||||
<F as Poseidon<WIDTH>>::constant_layer(&mut state, round_ctr);
|
||||
for i in 0..WIDTH {
|
||||
let sbox_in = vars.local_wires[Self::wire_full_sbox_1(r, i)];
|
||||
constraints.push(state[i] - sbox_in);
|
||||
state[i] = sbox_in;
|
||||
}
|
||||
<F as Poseidon>::sbox_layer(&mut state);
|
||||
state = <F as Poseidon>::mds_layer(&state);
|
||||
<F as Poseidon<WIDTH>>::sbox_layer(&mut state);
|
||||
state = <F as Poseidon<WIDTH>>::mds_layer(&state);
|
||||
round_ctr += 1;
|
||||
}
|
||||
|
||||
for i in 0..SPONGE_WIDTH {
|
||||
for i in 0..WIDTH {
|
||||
constraints.push(state[i] - vars.local_wires[Self::wire_output(i)]);
|
||||
}
|
||||
|
||||
@ -244,7 +277,7 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for PoseidonGate<F, D> {
|
||||
) -> Vec<ExtensionTarget<D>> {
|
||||
// The naive method is more efficient if we have enough routed wires for PoseidonMdsGate.
|
||||
let use_mds_gate =
|
||||
builder.config.num_routed_wires >= PoseidonMdsGate::<F, D>::new().num_wires();
|
||||
builder.config.num_routed_wires >= PoseidonMdsGate::<F, D, WIDTH>::new().num_wires();
|
||||
|
||||
let mut constraints = Vec::with_capacity(self.num_constraints());
|
||||
|
||||
@ -252,71 +285,73 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for PoseidonGate<F, D> {
|
||||
let swap = vars.local_wires[Self::WIRE_SWAP];
|
||||
constraints.push(builder.mul_sub_extension(swap, swap, swap));
|
||||
|
||||
let mut state = Vec::with_capacity(SPONGE_WIDTH);
|
||||
// We need to compute both `if swap {b} else {a}` and `if swap {a} else {b}`.
|
||||
// We will arithmetize them as
|
||||
// swap (b - a) + a
|
||||
// -swap (b - a) + b
|
||||
// so that `b - a` can be used for both.
|
||||
let mut state_first_4 = vec![];
|
||||
let mut state_next_4 = vec![];
|
||||
// Assert that each delta wire is set properly: `delta_i = swap * (rhs - lhs)`.
|
||||
for i in 0..4 {
|
||||
let a = vars.local_wires[i];
|
||||
let b = vars.local_wires[i + 4];
|
||||
let delta = builder.sub_extension(b, a);
|
||||
state_first_4.push(builder.mul_add_extension(swap, delta, a));
|
||||
state_next_4.push(builder.arithmetic_extension(F::NEG_ONE, F::ONE, swap, delta, b));
|
||||
let input_lhs = vars.local_wires[Self::wire_input(i)];
|
||||
let input_rhs = vars.local_wires[Self::wire_input(i + 4)];
|
||||
let delta_i = vars.local_wires[Self::wire_delta(i)];
|
||||
let diff = builder.sub_extension(input_rhs, input_lhs);
|
||||
constraints.push(builder.mul_sub_extension(swap, diff, delta_i));
|
||||
}
|
||||
|
||||
state.extend(state_first_4);
|
||||
state.extend(state_next_4);
|
||||
for i in 8..SPONGE_WIDTH {
|
||||
state.push(vars.local_wires[i]);
|
||||
// Compute the possibly-swapped input layer.
|
||||
let mut state = [builder.zero_extension(); WIDTH];
|
||||
for i in 0..4 {
|
||||
let delta_i = vars.local_wires[Self::wire_delta(i)];
|
||||
let input_lhs = vars.local_wires[Self::wire_input(i)];
|
||||
let input_rhs = vars.local_wires[Self::wire_input(i + 4)];
|
||||
state[i] = builder.add_extension(input_lhs, delta_i);
|
||||
state[i + 4] = builder.sub_extension(input_rhs, delta_i);
|
||||
}
|
||||
for i in 8..WIDTH {
|
||||
state[i] = vars.local_wires[Self::wire_input(i)];
|
||||
}
|
||||
|
||||
let mut state: [ExtensionTarget<D>; SPONGE_WIDTH] = state.try_into().unwrap();
|
||||
let mut round_ctr = 0;
|
||||
|
||||
// First set of full rounds.
|
||||
for r in 0..poseidon::HALF_N_FULL_ROUNDS {
|
||||
<F as Poseidon>::constant_layer_recursive(builder, &mut state, round_ctr);
|
||||
for i in 0..SPONGE_WIDTH {
|
||||
let sbox_in = vars.local_wires[Self::wire_full_sbox_0(r, i)];
|
||||
constraints.push(builder.sub_extension(state[i], sbox_in));
|
||||
state[i] = sbox_in;
|
||||
<F as Poseidon<WIDTH>>::constant_layer_recursive(builder, &mut state, round_ctr);
|
||||
if r != 0 {
|
||||
for i in 0..WIDTH {
|
||||
let sbox_in = vars.local_wires[Self::wire_full_sbox_0(r, i)];
|
||||
constraints.push(builder.sub_extension(state[i], sbox_in));
|
||||
state[i] = sbox_in;
|
||||
}
|
||||
}
|
||||
<F as Poseidon>::sbox_layer_recursive(builder, &mut state);
|
||||
state = <F as Poseidon>::mds_layer_recursive(builder, &state);
|
||||
<F as Poseidon<WIDTH>>::sbox_layer_recursive(builder, &mut state);
|
||||
state = <F as Poseidon<WIDTH>>::mds_layer_recursive(builder, &state);
|
||||
round_ctr += 1;
|
||||
}
|
||||
|
||||
// Partial rounds.
|
||||
if use_mds_gate {
|
||||
for r in 0..poseidon::N_PARTIAL_ROUNDS {
|
||||
<F as Poseidon>::constant_layer_recursive(builder, &mut state, round_ctr);
|
||||
<F as Poseidon<WIDTH>>::constant_layer_recursive(builder, &mut state, round_ctr);
|
||||
let sbox_in = vars.local_wires[Self::wire_partial_sbox(r)];
|
||||
constraints.push(builder.sub_extension(state[0], sbox_in));
|
||||
state[0] = <F as Poseidon>::sbox_monomial_recursive(builder, sbox_in);
|
||||
state = <F as Poseidon>::mds_layer_recursive(builder, &state);
|
||||
state[0] = <F as Poseidon<WIDTH>>::sbox_monomial_recursive(builder, sbox_in);
|
||||
state = <F as Poseidon<WIDTH>>::mds_layer_recursive(builder, &state);
|
||||
round_ctr += 1;
|
||||
}
|
||||
} else {
|
||||
<F as Poseidon>::partial_first_constant_layer_recursive(builder, &mut state);
|
||||
state = <F as Poseidon>::mds_partial_layer_init_recursive(builder, &mut state);
|
||||
<F as Poseidon<WIDTH>>::partial_first_constant_layer_recursive(builder, &mut state);
|
||||
state = <F as Poseidon<WIDTH>>::mds_partial_layer_init_recursive(builder, &state);
|
||||
for r in 0..(poseidon::N_PARTIAL_ROUNDS - 1) {
|
||||
let sbox_in = vars.local_wires[Self::wire_partial_sbox(r)];
|
||||
constraints.push(builder.sub_extension(state[0], sbox_in));
|
||||
state[0] = <F as Poseidon>::sbox_monomial_recursive(builder, sbox_in);
|
||||
state[0] = builder.add_const_extension(
|
||||
state[0],
|
||||
F::from_canonical_u64(<F as Poseidon>::FAST_PARTIAL_ROUND_CONSTANTS[r]),
|
||||
);
|
||||
state = <F as Poseidon>::mds_partial_layer_fast_recursive(builder, &state, r);
|
||||
state[0] = <F as Poseidon<WIDTH>>::sbox_monomial_recursive(builder, sbox_in);
|
||||
let c = <F as Poseidon<WIDTH>>::FAST_PARTIAL_ROUND_CONSTANTS[r];
|
||||
let c = F::Extension::from_canonical_u64(c);
|
||||
let c = builder.constant_extension(c);
|
||||
state[0] = builder.add_extension(state[0], c);
|
||||
state =
|
||||
<F as Poseidon<WIDTH>>::mds_partial_layer_fast_recursive(builder, &state, r);
|
||||
}
|
||||
let sbox_in = vars.local_wires[Self::wire_partial_sbox(poseidon::N_PARTIAL_ROUNDS - 1)];
|
||||
constraints.push(builder.sub_extension(state[0], sbox_in));
|
||||
state[0] = <F as Poseidon>::sbox_monomial_recursive(builder, sbox_in);
|
||||
state = <F as Poseidon>::mds_partial_layer_fast_recursive(
|
||||
state[0] = <F as Poseidon<WIDTH>>::sbox_monomial_recursive(builder, sbox_in);
|
||||
state = <F as Poseidon<WIDTH>>::mds_partial_layer_fast_recursive(
|
||||
builder,
|
||||
&state,
|
||||
poseidon::N_PARTIAL_ROUNDS - 1,
|
||||
@ -326,18 +361,18 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for PoseidonGate<F, D> {
|
||||
|
||||
// Second set of full rounds.
|
||||
for r in 0..poseidon::HALF_N_FULL_ROUNDS {
|
||||
<F as Poseidon>::constant_layer_recursive(builder, &mut state, round_ctr);
|
||||
for i in 0..SPONGE_WIDTH {
|
||||
<F as Poseidon<WIDTH>>::constant_layer_recursive(builder, &mut state, round_ctr);
|
||||
for i in 0..WIDTH {
|
||||
let sbox_in = vars.local_wires[Self::wire_full_sbox_1(r, i)];
|
||||
constraints.push(builder.sub_extension(state[i], sbox_in));
|
||||
state[i] = sbox_in;
|
||||
}
|
||||
<F as Poseidon>::sbox_layer_recursive(builder, &mut state);
|
||||
state = <F as Poseidon>::mds_layer_recursive(builder, &state);
|
||||
<F as Poseidon<WIDTH>>::sbox_layer_recursive(builder, &mut state);
|
||||
state = <F as Poseidon<WIDTH>>::mds_layer_recursive(builder, &state);
|
||||
round_ctr += 1;
|
||||
}
|
||||
|
||||
for i in 0..SPONGE_WIDTH {
|
||||
for i in 0..WIDTH {
|
||||
constraints
|
||||
.push(builder.sub_extension(state[i], vars.local_wires[Self::wire_output(i)]));
|
||||
}
|
||||
@ -350,7 +385,7 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for PoseidonGate<F, D> {
|
||||
gate_index: usize,
|
||||
_local_constants: &[F],
|
||||
) -> Vec<Box<dyn WitnessGenerator<F>>> {
|
||||
let gen = PoseidonGenerator::<F, D> {
|
||||
let gen = PoseidonGenerator::<F, D, WIDTH> {
|
||||
gate_index,
|
||||
_phantom: PhantomData,
|
||||
};
|
||||
@ -370,23 +405,31 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for PoseidonGate<F, D> {
|
||||
}
|
||||
|
||||
fn num_constraints(&self) -> usize {
|
||||
SPONGE_WIDTH * poseidon::N_FULL_ROUNDS_TOTAL + poseidon::N_PARTIAL_ROUNDS + SPONGE_WIDTH + 1
|
||||
WIDTH * (poseidon::N_FULL_ROUNDS_TOTAL - 1) + poseidon::N_PARTIAL_ROUNDS + WIDTH + 1 + 4
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct PoseidonGenerator<F: Extendable<D> + Poseidon, const D: usize> {
|
||||
struct PoseidonGenerator<
|
||||
F: RichField + Extendable<D> + Poseidon<WIDTH>,
|
||||
const D: usize,
|
||||
const WIDTH: usize,
|
||||
> where
|
||||
[(); WIDTH - 1]:,
|
||||
{
|
||||
gate_index: usize,
|
||||
_phantom: PhantomData<F>,
|
||||
}
|
||||
|
||||
impl<F: RichField + Extendable<D> + Poseidon, const D: usize> SimpleGenerator<F>
|
||||
for PoseidonGenerator<F, D>
|
||||
impl<F: RichField + Extendable<D> + Poseidon<WIDTH>, const D: usize, const WIDTH: usize>
|
||||
SimpleGenerator<F> for PoseidonGenerator<F, D, WIDTH>
|
||||
where
|
||||
[(); WIDTH - 1]:,
|
||||
{
|
||||
fn dependencies(&self) -> Vec<Target> {
|
||||
(0..SPONGE_WIDTH)
|
||||
.map(|i| PoseidonGate::<F, D>::wire_input(i))
|
||||
.chain(Some(PoseidonGate::<F, D>::WIRE_SWAP))
|
||||
(0..WIDTH)
|
||||
.map(|i| PoseidonGate::<F, D, WIDTH>::wire_input(i))
|
||||
.chain(Some(PoseidonGate::<F, D, WIDTH>::WIRE_SWAP))
|
||||
.map(|input| Target::wire(self.gate_index, input))
|
||||
.collect()
|
||||
}
|
||||
@ -397,87 +440,94 @@ impl<F: RichField + Extendable<D> + Poseidon, const D: usize> SimpleGenerator<F>
|
||||
input,
|
||||
};
|
||||
|
||||
let mut state = (0..SPONGE_WIDTH)
|
||||
.map(|i| {
|
||||
witness.get_wire(Wire {
|
||||
gate: self.gate_index,
|
||||
input: PoseidonGate::<F, D>::wire_input(i),
|
||||
})
|
||||
})
|
||||
let mut state = (0..WIDTH)
|
||||
.map(|i| witness.get_wire(local_wire(PoseidonGate::<F, D, WIDTH>::wire_input(i))))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let swap_value = witness.get_wire(Wire {
|
||||
gate: self.gate_index,
|
||||
input: PoseidonGate::<F, D>::WIRE_SWAP,
|
||||
});
|
||||
let swap_value = witness.get_wire(local_wire(PoseidonGate::<F, D, WIDTH>::WIRE_SWAP));
|
||||
debug_assert!(swap_value == F::ZERO || swap_value == F::ONE);
|
||||
|
||||
for i in 0..4 {
|
||||
let delta_i = swap_value * (state[i + 4] - state[i]);
|
||||
out_buffer.set_wire(
|
||||
local_wire(PoseidonGate::<F, D, WIDTH>::wire_delta(i)),
|
||||
delta_i,
|
||||
);
|
||||
}
|
||||
|
||||
if swap_value == F::ONE {
|
||||
for i in 0..4 {
|
||||
state.swap(i, 4 + i);
|
||||
}
|
||||
}
|
||||
|
||||
let mut state: [F; SPONGE_WIDTH] = state.try_into().unwrap();
|
||||
let mut state: [F; WIDTH] = state.try_into().unwrap();
|
||||
let mut round_ctr = 0;
|
||||
|
||||
for r in 0..poseidon::HALF_N_FULL_ROUNDS {
|
||||
<F as Poseidon>::constant_layer_field(&mut state, round_ctr);
|
||||
for i in 0..SPONGE_WIDTH {
|
||||
out_buffer.set_wire(
|
||||
local_wire(PoseidonGate::<F, D>::wire_full_sbox_0(r, i)),
|
||||
state[i],
|
||||
);
|
||||
<F as Poseidon<WIDTH>>::constant_layer_field(&mut state, round_ctr);
|
||||
if r != 0 {
|
||||
for i in 0..WIDTH {
|
||||
out_buffer.set_wire(
|
||||
local_wire(PoseidonGate::<F, D, WIDTH>::wire_full_sbox_0(r, i)),
|
||||
state[i],
|
||||
);
|
||||
}
|
||||
}
|
||||
<F as Poseidon>::sbox_layer_field(&mut state);
|
||||
state = <F as Poseidon>::mds_layer_field(&state);
|
||||
<F as Poseidon<WIDTH>>::sbox_layer_field(&mut state);
|
||||
state = <F as Poseidon<WIDTH>>::mds_layer_field(&state);
|
||||
round_ctr += 1;
|
||||
}
|
||||
|
||||
<F as Poseidon>::partial_first_constant_layer(&mut state);
|
||||
state = <F as Poseidon>::mds_partial_layer_init(&mut state);
|
||||
<F as Poseidon<WIDTH>>::partial_first_constant_layer(&mut state);
|
||||
state = <F as Poseidon<WIDTH>>::mds_partial_layer_init(&state);
|
||||
for r in 0..(poseidon::N_PARTIAL_ROUNDS - 1) {
|
||||
out_buffer.set_wire(
|
||||
local_wire(PoseidonGate::<F, D>::wire_partial_sbox(r)),
|
||||
local_wire(PoseidonGate::<F, D, WIDTH>::wire_partial_sbox(r)),
|
||||
state[0],
|
||||
);
|
||||
state[0] = <F as Poseidon>::sbox_monomial(state[0]);
|
||||
state[0] += F::from_canonical_u64(<F as Poseidon>::FAST_PARTIAL_ROUND_CONSTANTS[r]);
|
||||
state = <F as Poseidon>::mds_partial_layer_fast_field(&state, r);
|
||||
state[0] = <F as Poseidon<WIDTH>>::sbox_monomial(state[0]);
|
||||
state[0] +=
|
||||
F::from_canonical_u64(<F as Poseidon<WIDTH>>::FAST_PARTIAL_ROUND_CONSTANTS[r]);
|
||||
state = <F as Poseidon<WIDTH>>::mds_partial_layer_fast_field(&state, r);
|
||||
}
|
||||
out_buffer.set_wire(
|
||||
local_wire(PoseidonGate::<F, D>::wire_partial_sbox(
|
||||
local_wire(PoseidonGate::<F, D, WIDTH>::wire_partial_sbox(
|
||||
poseidon::N_PARTIAL_ROUNDS - 1,
|
||||
)),
|
||||
state[0],
|
||||
);
|
||||
state[0] = <F as Poseidon>::sbox_monomial(state[0]);
|
||||
state =
|
||||
<F as Poseidon>::mds_partial_layer_fast_field(&state, poseidon::N_PARTIAL_ROUNDS - 1);
|
||||
state[0] = <F as Poseidon<WIDTH>>::sbox_monomial(state[0]);
|
||||
state = <F as Poseidon<WIDTH>>::mds_partial_layer_fast_field(
|
||||
&state,
|
||||
poseidon::N_PARTIAL_ROUNDS - 1,
|
||||
);
|
||||
round_ctr += poseidon::N_PARTIAL_ROUNDS;
|
||||
|
||||
for r in 0..poseidon::HALF_N_FULL_ROUNDS {
|
||||
<F as Poseidon>::constant_layer_field(&mut state, round_ctr);
|
||||
for i in 0..SPONGE_WIDTH {
|
||||
<F as Poseidon<WIDTH>>::constant_layer_field(&mut state, round_ctr);
|
||||
for i in 0..WIDTH {
|
||||
out_buffer.set_wire(
|
||||
local_wire(PoseidonGate::<F, D>::wire_full_sbox_1(r, i)),
|
||||
local_wire(PoseidonGate::<F, D, WIDTH>::wire_full_sbox_1(r, i)),
|
||||
state[i],
|
||||
);
|
||||
}
|
||||
<F as Poseidon>::sbox_layer_field(&mut state);
|
||||
state = <F as Poseidon>::mds_layer_field(&state);
|
||||
<F as Poseidon<WIDTH>>::sbox_layer_field(&mut state);
|
||||
state = <F as Poseidon<WIDTH>>::mds_layer_field(&state);
|
||||
round_ctr += 1;
|
||||
}
|
||||
|
||||
for i in 0..SPONGE_WIDTH {
|
||||
out_buffer.set_wire(local_wire(PoseidonGate::<F, D>::wire_output(i)), state[i]);
|
||||
for i in 0..WIDTH {
|
||||
out_buffer.set_wire(
|
||||
local_wire(PoseidonGate::<F, D, WIDTH>::wire_output(i)),
|
||||
state[i],
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::convert::TryInto;
|
||||
|
||||
use anyhow::Result;
|
||||
|
||||
use crate::field::field_types::Field;
|
||||
@ -493,6 +543,29 @@ mod tests {
|
||||
use crate::plonk::circuit_data::CircuitConfig;
|
||||
use crate::plonk::config::{GenericConfig, PoseidonGoldilocksConfig};
|
||||
|
||||
#[test]
|
||||
fn wire_indices() {
|
||||
type F = GoldilocksField;
|
||||
const WIDTH: usize = 12;
|
||||
type Gate = PoseidonGate<F, 4, WIDTH>;
|
||||
|
||||
assert_eq!(Gate::wire_input(0), 0);
|
||||
assert_eq!(Gate::wire_input(11), 11);
|
||||
assert_eq!(Gate::wire_output(0), 12);
|
||||
assert_eq!(Gate::wire_output(11), 23);
|
||||
assert_eq!(Gate::WIRE_SWAP, 24);
|
||||
assert_eq!(Gate::wire_delta(0), 25);
|
||||
assert_eq!(Gate::wire_delta(3), 28);
|
||||
assert_eq!(Gate::wire_full_sbox_0(1, 0), 29);
|
||||
assert_eq!(Gate::wire_full_sbox_0(3, 0), 53);
|
||||
assert_eq!(Gate::wire_full_sbox_0(3, 11), 64);
|
||||
assert_eq!(Gate::wire_partial_sbox(0), 65);
|
||||
assert_eq!(Gate::wire_partial_sbox(21), 86);
|
||||
assert_eq!(Gate::wire_full_sbox_1(0, 0), 87);
|
||||
assert_eq!(Gate::wire_full_sbox_1(3, 0), 123);
|
||||
assert_eq!(Gate::wire_full_sbox_1(3, 11), 134);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn generated_output() {
|
||||
const D: usize = 2;
|
||||
|
||||
@ -1,4 +1,3 @@
|
||||
use std::convert::TryInto;
|
||||
use std::marker::PhantomData;
|
||||
use std::ops::Range;
|
||||
|
||||
@ -6,9 +5,8 @@ use crate::field::extension_field::algebra::ExtensionAlgebra;
|
||||
use crate::field::extension_field::target::{ExtensionAlgebraTarget, ExtensionTarget};
|
||||
use crate::field::extension_field::Extendable;
|
||||
use crate::field::extension_field::FieldExtension;
|
||||
use crate::field::field_types::Field;
|
||||
use crate::field::field_types::{Field, RichField};
|
||||
use crate::gates::gate::Gate;
|
||||
use crate::hash::hashing::SPONGE_WIDTH;
|
||||
use crate::hash::poseidon::Poseidon;
|
||||
use crate::iop::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator};
|
||||
use crate::iop::target::Target;
|
||||
@ -17,11 +15,21 @@ use crate::plonk::circuit_builder::CircuitBuilder;
|
||||
use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase};
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct PoseidonMdsGate<F: Extendable<D> + Poseidon, const D: usize> {
|
||||
pub struct PoseidonMdsGate<
|
||||
F: RichField + Extendable<D> + Poseidon<WIDTH>,
|
||||
const D: usize,
|
||||
const WIDTH: usize,
|
||||
> where
|
||||
[(); WIDTH - 1]:,
|
||||
{
|
||||
_phantom: PhantomData<F>,
|
||||
}
|
||||
|
||||
impl<F: Extendable<D> + Poseidon, const D: usize> PoseidonMdsGate<F, D> {
|
||||
impl<F: RichField + Extendable<D> + Poseidon<WIDTH>, const D: usize, const WIDTH: usize>
|
||||
PoseidonMdsGate<F, D, WIDTH>
|
||||
where
|
||||
[(); WIDTH - 1]:,
|
||||
{
|
||||
pub fn new() -> Self {
|
||||
PoseidonMdsGate {
|
||||
_phantom: PhantomData,
|
||||
@ -29,13 +37,13 @@ impl<F: Extendable<D> + Poseidon, const D: usize> PoseidonMdsGate<F, D> {
|
||||
}
|
||||
|
||||
pub fn wires_input(i: usize) -> Range<usize> {
|
||||
assert!(i < SPONGE_WIDTH);
|
||||
assert!(i < WIDTH);
|
||||
i * D..(i + 1) * D
|
||||
}
|
||||
|
||||
pub fn wires_output(i: usize) -> Range<usize> {
|
||||
assert!(i < SPONGE_WIDTH);
|
||||
(SPONGE_WIDTH + i) * D..(SPONGE_WIDTH + i + 1) * D
|
||||
assert!(i < WIDTH);
|
||||
(WIDTH + i) * D..(WIDTH + i + 1) * D
|
||||
}
|
||||
|
||||
// Following are methods analogous to ones in `Poseidon`, but for extension algebras.
|
||||
@ -43,14 +51,15 @@ impl<F: Extendable<D> + Poseidon, const D: usize> PoseidonMdsGate<F, D> {
|
||||
/// Same as `mds_row_shf` for an extension algebra of `F`.
|
||||
fn mds_row_shf_algebra(
|
||||
r: usize,
|
||||
v: &[ExtensionAlgebra<F::Extension, D>; SPONGE_WIDTH],
|
||||
v: &[ExtensionAlgebra<F::Extension, D>; WIDTH],
|
||||
) -> ExtensionAlgebra<F::Extension, D> {
|
||||
debug_assert!(r < SPONGE_WIDTH);
|
||||
debug_assert!(r < WIDTH);
|
||||
let mut res = ExtensionAlgebra::ZERO;
|
||||
|
||||
for i in 0..SPONGE_WIDTH {
|
||||
let coeff = F::Extension::from_canonical_u64(1 << <F as Poseidon>::MDS_MATRIX_EXPS[i]);
|
||||
res += v[(i + r) % SPONGE_WIDTH].scalar_mul(coeff);
|
||||
for i in 0..WIDTH {
|
||||
let coeff =
|
||||
F::Extension::from_canonical_u64(1 << <F as Poseidon<WIDTH>>::MDS_MATRIX_EXPS[i]);
|
||||
res += v[(i + r) % WIDTH].scalar_mul(coeff);
|
||||
}
|
||||
|
||||
res
|
||||
@ -60,16 +69,16 @@ impl<F: Extendable<D> + Poseidon, const D: usize> PoseidonMdsGate<F, D> {
|
||||
fn mds_row_shf_algebra_recursive(
|
||||
builder: &mut CircuitBuilder<F, D>,
|
||||
r: usize,
|
||||
v: &[ExtensionAlgebraTarget<D>; SPONGE_WIDTH],
|
||||
v: &[ExtensionAlgebraTarget<D>; WIDTH],
|
||||
) -> ExtensionAlgebraTarget<D> {
|
||||
debug_assert!(r < SPONGE_WIDTH);
|
||||
debug_assert!(r < WIDTH);
|
||||
let mut res = builder.zero_ext_algebra();
|
||||
|
||||
for i in 0..SPONGE_WIDTH {
|
||||
for i in 0..WIDTH {
|
||||
let coeff = builder.constant_extension(F::Extension::from_canonical_u64(
|
||||
1 << <F as Poseidon>::MDS_MATRIX_EXPS[i],
|
||||
1 << <F as Poseidon<WIDTH>>::MDS_MATRIX_EXPS[i],
|
||||
));
|
||||
res = builder.scalar_mul_add_ext_algebra(coeff, v[(i + r) % SPONGE_WIDTH], res);
|
||||
res = builder.scalar_mul_add_ext_algebra(coeff, v[(i + r) % WIDTH], res);
|
||||
}
|
||||
|
||||
res
|
||||
@ -77,11 +86,11 @@ impl<F: Extendable<D> + Poseidon, const D: usize> PoseidonMdsGate<F, D> {
|
||||
|
||||
/// Same as `mds_layer` for an extension algebra of `F`.
|
||||
fn mds_layer_algebra(
|
||||
state: &[ExtensionAlgebra<F::Extension, D>; SPONGE_WIDTH],
|
||||
) -> [ExtensionAlgebra<F::Extension, D>; SPONGE_WIDTH] {
|
||||
let mut result = [ExtensionAlgebra::ZERO; SPONGE_WIDTH];
|
||||
state: &[ExtensionAlgebra<F::Extension, D>; WIDTH],
|
||||
) -> [ExtensionAlgebra<F::Extension, D>; WIDTH] {
|
||||
let mut result = [ExtensionAlgebra::ZERO; WIDTH];
|
||||
|
||||
for r in 0..SPONGE_WIDTH {
|
||||
for r in 0..WIDTH {
|
||||
result[r] = Self::mds_row_shf_algebra(r, state);
|
||||
}
|
||||
|
||||
@ -91,11 +100,11 @@ impl<F: Extendable<D> + Poseidon, const D: usize> PoseidonMdsGate<F, D> {
|
||||
/// Same as `mds_layer_recursive` for an extension algebra of `F`.
|
||||
fn mds_layer_algebra_recursive(
|
||||
builder: &mut CircuitBuilder<F, D>,
|
||||
state: &[ExtensionAlgebraTarget<D>; SPONGE_WIDTH],
|
||||
) -> [ExtensionAlgebraTarget<D>; SPONGE_WIDTH] {
|
||||
let mut result = [builder.zero_ext_algebra(); SPONGE_WIDTH];
|
||||
state: &[ExtensionAlgebraTarget<D>; WIDTH],
|
||||
) -> [ExtensionAlgebraTarget<D>; WIDTH] {
|
||||
let mut result = [builder.zero_ext_algebra(); WIDTH];
|
||||
|
||||
for r in 0..SPONGE_WIDTH {
|
||||
for r in 0..WIDTH {
|
||||
result[r] = Self::mds_row_shf_algebra_recursive(builder, r, state);
|
||||
}
|
||||
|
||||
@ -103,13 +112,17 @@ impl<F: Extendable<D> + Poseidon, const D: usize> PoseidonMdsGate<F, D> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<F: Extendable<D> + Poseidon, const D: usize> Gate<F, D> for PoseidonMdsGate<F, D> {
|
||||
impl<F: RichField + Extendable<D> + Poseidon<WIDTH>, const D: usize, const WIDTH: usize> Gate<F, D>
|
||||
for PoseidonMdsGate<F, D, WIDTH>
|
||||
where
|
||||
[(); WIDTH - 1]:,
|
||||
{
|
||||
fn id(&self) -> String {
|
||||
format!("{:?}<WIDTH={}>", self, SPONGE_WIDTH)
|
||||
format!("{:?}<WIDTH={}>", self, WIDTH)
|
||||
}
|
||||
|
||||
fn eval_unfiltered(&self, vars: EvaluationVars<F, D>) -> Vec<F::Extension> {
|
||||
let inputs: [_; SPONGE_WIDTH] = (0..SPONGE_WIDTH)
|
||||
let inputs: [_; WIDTH] = (0..WIDTH)
|
||||
.map(|i| vars.get_local_ext_algebra(Self::wires_input(i)))
|
||||
.collect::<Vec<_>>()
|
||||
.try_into()
|
||||
@ -117,7 +130,7 @@ impl<F: Extendable<D> + Poseidon, const D: usize> Gate<F, D> for PoseidonMdsGate
|
||||
|
||||
let computed_outputs = Self::mds_layer_algebra(&inputs);
|
||||
|
||||
(0..SPONGE_WIDTH)
|
||||
(0..WIDTH)
|
||||
.map(|i| vars.get_local_ext_algebra(Self::wires_output(i)))
|
||||
.zip(computed_outputs)
|
||||
.flat_map(|(out, computed_out)| (out - computed_out).to_basefield_array())
|
||||
@ -125,7 +138,7 @@ impl<F: Extendable<D> + Poseidon, const D: usize> Gate<F, D> for PoseidonMdsGate
|
||||
}
|
||||
|
||||
fn eval_unfiltered_base(&self, vars: EvaluationVarsBase<F>) -> Vec<F> {
|
||||
let inputs: [_; SPONGE_WIDTH] = (0..SPONGE_WIDTH)
|
||||
let inputs: [_; WIDTH] = (0..WIDTH)
|
||||
.map(|i| vars.get_local_ext(Self::wires_input(i)))
|
||||
.collect::<Vec<_>>()
|
||||
.try_into()
|
||||
@ -133,7 +146,7 @@ impl<F: Extendable<D> + Poseidon, const D: usize> Gate<F, D> for PoseidonMdsGate
|
||||
|
||||
let computed_outputs = F::mds_layer_field(&inputs);
|
||||
|
||||
(0..SPONGE_WIDTH)
|
||||
(0..WIDTH)
|
||||
.map(|i| vars.get_local_ext(Self::wires_output(i)))
|
||||
.zip(computed_outputs)
|
||||
.flat_map(|(out, computed_out)| (out - computed_out).to_basefield_array())
|
||||
@ -145,7 +158,7 @@ impl<F: Extendable<D> + Poseidon, const D: usize> Gate<F, D> for PoseidonMdsGate
|
||||
builder: &mut CircuitBuilder<F, D>,
|
||||
vars: EvaluationTargets<D>,
|
||||
) -> Vec<ExtensionTarget<D>> {
|
||||
let inputs: [_; SPONGE_WIDTH] = (0..SPONGE_WIDTH)
|
||||
let inputs: [_; WIDTH] = (0..WIDTH)
|
||||
.map(|i| vars.get_local_ext_algebra(Self::wires_input(i)))
|
||||
.collect::<Vec<_>>()
|
||||
.try_into()
|
||||
@ -153,7 +166,7 @@ impl<F: Extendable<D> + Poseidon, const D: usize> Gate<F, D> for PoseidonMdsGate
|
||||
|
||||
let computed_outputs = Self::mds_layer_algebra_recursive(builder, &inputs);
|
||||
|
||||
(0..SPONGE_WIDTH)
|
||||
(0..WIDTH)
|
||||
.map(|i| vars.get_local_ext_algebra(Self::wires_output(i)))
|
||||
.zip(computed_outputs)
|
||||
.flat_map(|(out, computed_out)| {
|
||||
@ -169,12 +182,12 @@ impl<F: Extendable<D> + Poseidon, const D: usize> Gate<F, D> for PoseidonMdsGate
|
||||
gate_index: usize,
|
||||
_local_constants: &[F],
|
||||
) -> Vec<Box<dyn WitnessGenerator<F>>> {
|
||||
let gen = PoseidonMdsGenerator::<D> { gate_index };
|
||||
let gen = PoseidonMdsGenerator::<D, WIDTH> { gate_index };
|
||||
vec![Box::new(gen.adapter())]
|
||||
}
|
||||
|
||||
fn num_wires(&self) -> usize {
|
||||
2 * D * SPONGE_WIDTH
|
||||
2 * D * WIDTH
|
||||
}
|
||||
|
||||
fn num_constants(&self) -> usize {
|
||||
@ -186,20 +199,30 @@ impl<F: Extendable<D> + Poseidon, const D: usize> Gate<F, D> for PoseidonMdsGate
|
||||
}
|
||||
|
||||
fn num_constraints(&self) -> usize {
|
||||
SPONGE_WIDTH * D
|
||||
WIDTH * D
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
struct PoseidonMdsGenerator<const D: usize> {
|
||||
struct PoseidonMdsGenerator<const D: usize, const WIDTH: usize>
|
||||
where
|
||||
[(); WIDTH - 1]:,
|
||||
{
|
||||
gate_index: usize,
|
||||
}
|
||||
|
||||
impl<F: Extendable<D> + Poseidon, const D: usize> SimpleGenerator<F> for PoseidonMdsGenerator<D> {
|
||||
impl<F: RichField + Extendable<D> + Poseidon<WIDTH>, const D: usize, const WIDTH: usize>
|
||||
SimpleGenerator<F> for PoseidonMdsGenerator<D, WIDTH>
|
||||
where
|
||||
[(); WIDTH - 1]:,
|
||||
{
|
||||
fn dependencies(&self) -> Vec<Target> {
|
||||
(0..SPONGE_WIDTH)
|
||||
(0..WIDTH)
|
||||
.flat_map(|i| {
|
||||
Target::wires_from_range(self.gate_index, PoseidonMdsGate::<F, D>::wires_input(i))
|
||||
Target::wires_from_range(
|
||||
self.gate_index,
|
||||
PoseidonMdsGate::<F, D, WIDTH>::wires_input(i),
|
||||
)
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
@ -210,8 +233,8 @@ impl<F: Extendable<D> + Poseidon, const D: usize> SimpleGenerator<F> for Poseido
|
||||
let get_local_ext =
|
||||
|wire_range| witness.get_extension_target(get_local_get_target(wire_range));
|
||||
|
||||
let inputs: [_; SPONGE_WIDTH] = (0..SPONGE_WIDTH)
|
||||
.map(|i| get_local_ext(PoseidonMdsGate::<F, D>::wires_input(i)))
|
||||
let inputs: [_; WIDTH] = (0..WIDTH)
|
||||
.map(|i| get_local_ext(PoseidonMdsGate::<F, D, WIDTH>::wires_input(i)))
|
||||
.collect::<Vec<_>>()
|
||||
.try_into()
|
||||
.unwrap();
|
||||
@ -220,7 +243,7 @@ impl<F: Extendable<D> + Poseidon, const D: usize> SimpleGenerator<F> for Poseido
|
||||
|
||||
for (i, &out) in outputs.iter().enumerate() {
|
||||
out_buffer.set_extension_target(
|
||||
get_local_get_target(PoseidonMdsGate::<F, D>::wires_output(i)),
|
||||
get_local_get_target(PoseidonMdsGate::<F, D, WIDTH>::wires_output(i)),
|
||||
out,
|
||||
);
|
||||
}
|
||||
@ -232,21 +255,19 @@ mod tests {
|
||||
use crate::field::goldilocks_field::GoldilocksField;
|
||||
use crate::gates::gate_testing::{test_eval_fns, test_low_degree};
|
||||
use crate::gates::poseidon_mds::PoseidonMdsGate;
|
||||
use crate::plonk::config::{GenericConfig, PoseidonGoldilocksConfig};
|
||||
use crate::hash::hashing::SPONGE_WIDTH;
|
||||
|
||||
#[test]
|
||||
fn low_degree() {
|
||||
type F = GoldilocksField;
|
||||
let gate = PoseidonMdsGate::<F, 4>::new();
|
||||
let gate = PoseidonMdsGate::<F, 4, SPONGE_WIDTH>::new();
|
||||
test_low_degree(gate)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn eval_fns() -> anyhow::Result<()> {
|
||||
const D: usize = 2;
|
||||
type C = PoseidonGoldilocksConfig;
|
||||
type F = <C as GenericConfig<D>>::F;
|
||||
let gate = PoseidonMdsGate::<F, D>::new();
|
||||
test_eval_fns::<F, C, _, D>(gate)
|
||||
type F = GoldilocksField;
|
||||
let gate = PoseidonMdsGate::<F, 4, SPONGE_WIDTH>::new();
|
||||
test_eval_fns(gate)
|
||||
}
|
||||
}
|
||||
|
||||
@ -1,5 +1,7 @@
|
||||
use std::marker::PhantomData;
|
||||
|
||||
use itertools::Itertools;
|
||||
|
||||
use crate::field::extension_field::target::ExtensionTarget;
|
||||
use crate::field::extension_field::Extendable;
|
||||
use crate::field::field_types::Field;
|
||||
@ -14,76 +16,65 @@ use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase};
|
||||
|
||||
/// A gate for checking that a particular element of a list matches a given value.
|
||||
#[derive(Copy, Clone, Debug)]
|
||||
pub(crate) struct RandomAccessGate<F: Extendable<D>, const D: usize> {
|
||||
pub vec_size: usize,
|
||||
pub(crate) struct RandomAccessGate<F: Extendable<D>, const D: usize> {
|
||||
pub bits: usize,
|
||||
pub num_copies: usize,
|
||||
_phantom: PhantomData<F>,
|
||||
}
|
||||
|
||||
impl<F: Extendable<D>, const D: usize> RandomAccessGate<F, D> {
|
||||
pub fn new(num_copies: usize, vec_size: usize) -> Self {
|
||||
impl<F: Extendable<D>, const D: usize> RandomAccessGate<F, D> {
|
||||
fn new(num_copies: usize, bits: usize) -> Self {
|
||||
Self {
|
||||
vec_size,
|
||||
bits,
|
||||
num_copies,
|
||||
_phantom: PhantomData,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn new_from_config(config: &CircuitConfig, vec_size: usize) -> Self {
|
||||
let num_copies = Self::max_num_copies(config.num_routed_wires, config.num_wires, vec_size);
|
||||
Self::new(num_copies, vec_size)
|
||||
pub fn new_from_config(config: &CircuitConfig, bits: usize) -> Self {
|
||||
let vec_size = 1 << bits;
|
||||
// Need `(2 + vec_size) * num_copies` routed wires
|
||||
let max_copies = (config.num_routed_wires / (2 + vec_size)).min(
|
||||
// Need `(2 + vec_size + bits) * num_copies` wires
|
||||
config.num_wires / (2 + vec_size + bits),
|
||||
);
|
||||
Self::new(max_copies, bits)
|
||||
}
|
||||
|
||||
pub fn max_num_copies(num_routed_wires: usize, num_wires: usize, vec_size: usize) -> usize {
|
||||
// Need `(2 + vec_size) * num_copies` routed wires
|
||||
(num_routed_wires / (2 + vec_size)).min(
|
||||
// Need `(2 + 3*vec_size) * num_copies` wires
|
||||
num_wires / (2 + 3 * vec_size),
|
||||
)
|
||||
fn vec_size(&self) -> usize {
|
||||
1 << self.bits
|
||||
}
|
||||
|
||||
pub fn wire_access_index(&self, copy: usize) -> usize {
|
||||
debug_assert!(copy < self.num_copies);
|
||||
(2 + self.vec_size) * copy
|
||||
(2 + self.vec_size()) * copy
|
||||
}
|
||||
|
||||
pub fn wire_claimed_element(&self, copy: usize) -> usize {
|
||||
debug_assert!(copy < self.num_copies);
|
||||
(2 + self.vec_size) * copy + 1
|
||||
(2 + self.vec_size()) * copy + 1
|
||||
}
|
||||
|
||||
pub fn wire_list_item(&self, i: usize, copy: usize) -> usize {
|
||||
debug_assert!(i < self.vec_size);
|
||||
debug_assert!(i < self.vec_size());
|
||||
debug_assert!(copy < self.num_copies);
|
||||
(2 + self.vec_size) * copy + 2 + i
|
||||
(2 + self.vec_size()) * copy + 2 + i
|
||||
}
|
||||
|
||||
fn start_of_intermediate_wires(&self) -> usize {
|
||||
(2 + self.vec_size) * self.num_copies
|
||||
(2 + self.vec_size()) * self.num_copies
|
||||
}
|
||||
|
||||
pub(crate) fn num_routed_wires(&self) -> usize {
|
||||
self.start_of_intermediate_wires()
|
||||
}
|
||||
|
||||
/// An intermediate wire for a dummy variable used to show equality.
|
||||
/// The prover sets this to 1/(x-y) if x != y, or to an arbitrary value if
|
||||
/// x == y.
|
||||
pub fn wire_equality_dummy_for_index(&self, i: usize, copy: usize) -> usize {
|
||||
debug_assert!(i < self.vec_size);
|
||||
/// An intermediate wire where the prover gives the (purported) binary decomposition of the
|
||||
/// index.
|
||||
pub fn wire_bit(&self, i: usize, copy: usize) -> usize {
|
||||
debug_assert!(i < self.bits);
|
||||
debug_assert!(copy < self.num_copies);
|
||||
self.start_of_intermediate_wires() + copy * self.vec_size + i
|
||||
}
|
||||
|
||||
/// An intermediate wire for the "index_matches" variable (1 if the current index is the index at
|
||||
/// which to compare, 0 otherwise).
|
||||
pub fn wire_index_matches_for_index(&self, i: usize, copy: usize) -> usize {
|
||||
debug_assert!(i < self.vec_size);
|
||||
debug_assert!(copy < self.num_copies);
|
||||
self.start_of_intermediate_wires()
|
||||
+ self.vec_size * self.num_copies
|
||||
+ self.vec_size * copy
|
||||
+ i
|
||||
self.start_of_intermediate_wires() + copy * self.bits + i
|
||||
}
|
||||
}
|
||||
|
||||
@ -97,23 +88,38 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for RandomAccessGate<F, D> {
|
||||
|
||||
for copy in 0..self.num_copies {
|
||||
let access_index = vars.local_wires[self.wire_access_index(copy)];
|
||||
let list_items = (0..self.vec_size)
|
||||
let mut list_items = (0..self.vec_size())
|
||||
.map(|i| vars.local_wires[self.wire_list_item(i, copy)])
|
||||
.collect::<Vec<_>>();
|
||||
let claimed_element = vars.local_wires[self.wire_claimed_element(copy)];
|
||||
let bits = (0..self.bits)
|
||||
.map(|i| vars.local_wires[self.wire_bit(i, copy)])
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
for i in 0..self.vec_size {
|
||||
let cur_index = F::Extension::from_canonical_usize(i);
|
||||
let difference = cur_index - access_index;
|
||||
let equality_dummy = vars.local_wires[self.wire_equality_dummy_for_index(i, copy)];
|
||||
let index_matches = vars.local_wires[self.wire_index_matches_for_index(i, copy)];
|
||||
|
||||
// The two index equality constraints.
|
||||
constraints.push(difference * equality_dummy - (F::Extension::ONE - index_matches));
|
||||
constraints.push(index_matches * difference);
|
||||
// Value equality constraint.
|
||||
constraints.push((list_items[i] - claimed_element) * index_matches);
|
||||
// Assert that each bit wire value is indeed boolean.
|
||||
for &b in &bits {
|
||||
constraints.push(b * (b - F::Extension::ONE));
|
||||
}
|
||||
|
||||
// Assert that the binary decomposition was correct.
|
||||
let reconstructed_index = bits
|
||||
.iter()
|
||||
.rev()
|
||||
.fold(F::Extension::ZERO, |acc, &b| acc.double() + b);
|
||||
constraints.push(reconstructed_index - access_index);
|
||||
|
||||
// Repeatedly fold the list, selecting the left or right item from each pair based on
|
||||
// the corresponding bit.
|
||||
for b in bits {
|
||||
list_items = list_items
|
||||
.iter()
|
||||
.tuples()
|
||||
.map(|(&x, &y)| x + b * (y - x))
|
||||
.collect()
|
||||
}
|
||||
|
||||
debug_assert_eq!(list_items.len(), 1);
|
||||
constraints.push(list_items[0] - claimed_element);
|
||||
}
|
||||
|
||||
constraints
|
||||
@ -124,23 +130,35 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for RandomAccessGate<F, D> {
|
||||
|
||||
for copy in 0..self.num_copies {
|
||||
let access_index = vars.local_wires[self.wire_access_index(copy)];
|
||||
let list_items = (0..self.vec_size)
|
||||
let mut list_items = (0..self.vec_size())
|
||||
.map(|i| vars.local_wires[self.wire_list_item(i, copy)])
|
||||
.collect::<Vec<_>>();
|
||||
let claimed_element = vars.local_wires[self.wire_claimed_element(copy)];
|
||||
let bits = (0..self.bits)
|
||||
.map(|i| vars.local_wires[self.wire_bit(i, copy)])
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
for i in 0..self.vec_size {
|
||||
let cur_index = F::from_canonical_usize(i);
|
||||
let difference = cur_index - access_index;
|
||||
let equality_dummy = vars.local_wires[self.wire_equality_dummy_for_index(i, copy)];
|
||||
let index_matches = vars.local_wires[self.wire_index_matches_for_index(i, copy)];
|
||||
|
||||
// The two index equality constraints.
|
||||
constraints.push(difference * equality_dummy - (F::ONE - index_matches));
|
||||
constraints.push(index_matches * difference);
|
||||
// Value equality constraint.
|
||||
constraints.push((list_items[i] - claimed_element) * index_matches);
|
||||
// Assert that each bit wire value is indeed boolean.
|
||||
for &b in &bits {
|
||||
constraints.push(b * (b - F::ONE));
|
||||
}
|
||||
|
||||
// Assert that the binary decomposition was correct.
|
||||
let reconstructed_index = bits.iter().rev().fold(F::ZERO, |acc, &b| acc.double() + b);
|
||||
constraints.push(reconstructed_index - access_index);
|
||||
|
||||
// Repeatedly fold the list, selecting the left or right item from each pair based on
|
||||
// the corresponding bit.
|
||||
for b in bits {
|
||||
list_items = list_items
|
||||
.iter()
|
||||
.tuples()
|
||||
.map(|(&x, &y)| x + b * (y - x))
|
||||
.collect()
|
||||
}
|
||||
|
||||
debug_assert_eq!(list_items.len(), 1);
|
||||
constraints.push(list_items[0] - claimed_element);
|
||||
}
|
||||
|
||||
constraints
|
||||
@ -151,36 +169,44 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for RandomAccessGate<F, D> {
|
||||
builder: &mut CircuitBuilder<F, D>,
|
||||
vars: EvaluationTargets<D>,
|
||||
) -> Vec<ExtensionTarget<D>> {
|
||||
let zero = builder.zero_extension();
|
||||
let two = builder.two_extension();
|
||||
let mut constraints = Vec::with_capacity(self.num_constraints());
|
||||
|
||||
for copy in 0..self.num_copies {
|
||||
let access_index = vars.local_wires[self.wire_access_index(copy)];
|
||||
let list_items = (0..self.vec_size)
|
||||
let mut list_items = (0..self.vec_size())
|
||||
.map(|i| vars.local_wires[self.wire_list_item(i, copy)])
|
||||
.collect::<Vec<_>>();
|
||||
let claimed_element = vars.local_wires[self.wire_claimed_element(copy)];
|
||||
let bits = (0..self.bits)
|
||||
.map(|i| vars.local_wires[self.wire_bit(i, copy)])
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
for i in 0..self.vec_size {
|
||||
let cur_index_ext = F::Extension::from_canonical_usize(i);
|
||||
let cur_index = builder.constant_extension(cur_index_ext);
|
||||
let difference = builder.sub_extension(cur_index, access_index);
|
||||
let equality_dummy = vars.local_wires[self.wire_equality_dummy_for_index(i, copy)];
|
||||
let index_matches = vars.local_wires[self.wire_index_matches_for_index(i, copy)];
|
||||
|
||||
let one = builder.one_extension();
|
||||
let not_index_matches = builder.sub_extension(one, index_matches);
|
||||
let first_equality_constraint =
|
||||
builder.mul_sub_extension(difference, equality_dummy, not_index_matches);
|
||||
constraints.push(first_equality_constraint);
|
||||
|
||||
let second_equality_constraint = builder.mul_extension(index_matches, difference);
|
||||
constraints.push(second_equality_constraint);
|
||||
|
||||
// Output constraint.
|
||||
let diff = builder.sub_extension(list_items[i], claimed_element);
|
||||
let conditional_diff = builder.mul_extension(index_matches, diff);
|
||||
constraints.push(conditional_diff);
|
||||
// Assert that each bit wire value is indeed boolean.
|
||||
for &b in &bits {
|
||||
constraints.push(builder.mul_sub_extension(b, b, b));
|
||||
}
|
||||
|
||||
// Assert that the binary decomposition was correct.
|
||||
let reconstructed_index = bits
|
||||
.iter()
|
||||
.rev()
|
||||
.fold(zero, |acc, &b| builder.mul_add_extension(acc, two, b));
|
||||
constraints.push(builder.sub_extension(reconstructed_index, access_index));
|
||||
|
||||
// Repeatedly fold the list, selecting the left or right item from each pair based on
|
||||
// the corresponding bit.
|
||||
for b in bits {
|
||||
list_items = list_items
|
||||
.iter()
|
||||
.tuples()
|
||||
.map(|(&x, &y)| builder.select_ext_generalized(b, y, x))
|
||||
.collect()
|
||||
}
|
||||
|
||||
debug_assert_eq!(list_items.len(), 1);
|
||||
constraints.push(builder.sub_extension(list_items[0], claimed_element));
|
||||
}
|
||||
|
||||
constraints
|
||||
@ -207,7 +233,7 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for RandomAccessGate<F, D> {
|
||||
}
|
||||
|
||||
fn num_wires(&self) -> usize {
|
||||
self.wire_index_matches_for_index(self.vec_size - 1, self.num_copies - 1) + 1
|
||||
self.wire_bit(self.bits - 1, self.num_copies - 1) + 1
|
||||
}
|
||||
|
||||
fn num_constants(&self) -> usize {
|
||||
@ -215,11 +241,12 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for RandomAccessGate<F, D> {
|
||||
}
|
||||
|
||||
fn degree(&self) -> usize {
|
||||
2
|
||||
self.bits + 1
|
||||
}
|
||||
|
||||
fn num_constraints(&self) -> usize {
|
||||
3 * self.num_copies * self.vec_size
|
||||
let constraints_per_copy = self.bits + 2;
|
||||
self.num_copies * constraints_per_copy
|
||||
}
|
||||
}
|
||||
|
||||
@ -234,10 +261,8 @@ impl<F: Extendable<D>, const D: usize> SimpleGenerator<F> for RandomAccessGenera
|
||||
fn dependencies(&self) -> Vec<Target> {
|
||||
let local_target = |input| Target::wire(self.gate_index, input);
|
||||
|
||||
let mut deps = Vec::new();
|
||||
deps.push(local_target(self.gate.wire_access_index(self.copy)));
|
||||
deps.push(local_target(self.gate.wire_claimed_element(self.copy)));
|
||||
for i in 0..self.gate.vec_size {
|
||||
let mut deps = vec![local_target(self.gate.wire_access_index(self.copy))];
|
||||
for i in 0..self.gate.vec_size() {
|
||||
deps.push(local_target(self.gate.wire_list_item(i, self.copy)));
|
||||
}
|
||||
deps
|
||||
@ -250,11 +275,12 @@ impl<F: Extendable<D>, const D: usize> SimpleGenerator<F> for RandomAccessGenera
|
||||
};
|
||||
|
||||
let get_local_wire = |input| witness.get_wire(local_wire(input));
|
||||
let mut set_local_wire = |input, value| out_buffer.set_wire(local_wire(input), value);
|
||||
|
||||
// Compute the new vector and the values for equality_dummy and index_matches
|
||||
let vec_size = self.gate.vec_size;
|
||||
let access_index_f = get_local_wire(self.gate.wire_access_index(self.copy));
|
||||
let copy = self.copy;
|
||||
let vec_size = self.gate.vec_size();
|
||||
|
||||
let access_index_f = get_local_wire(self.gate.wire_access_index(copy));
|
||||
let access_index = access_index_f.to_canonical_u64() as usize;
|
||||
debug_assert!(
|
||||
access_index < vec_size,
|
||||
@ -263,22 +289,14 @@ impl<F: Extendable<D>, const D: usize> SimpleGenerator<F> for RandomAccessGenera
|
||||
vec_size
|
||||
);
|
||||
|
||||
for i in 0..vec_size {
|
||||
let equality_dummy_wire =
|
||||
local_wire(self.gate.wire_equality_dummy_for_index(i, self.copy));
|
||||
let index_matches_wire =
|
||||
local_wire(self.gate.wire_index_matches_for_index(i, self.copy));
|
||||
set_local_wire(
|
||||
self.gate.wire_claimed_element(copy),
|
||||
get_local_wire(self.gate.wire_list_item(access_index, copy)),
|
||||
);
|
||||
|
||||
if i == access_index {
|
||||
out_buffer.set_wire(equality_dummy_wire, F::ONE);
|
||||
out_buffer.set_wire(index_matches_wire, F::ONE);
|
||||
} else {
|
||||
out_buffer.set_wire(
|
||||
equality_dummy_wire,
|
||||
(F::from_canonical_usize(i) - F::from_canonical_usize(access_index)).inverse(),
|
||||
);
|
||||
out_buffer.set_wire(index_matches_wire, F::ZERO);
|
||||
}
|
||||
for i in 0..self.gate.bits {
|
||||
let bit = F::from_bool(((access_index >> i) & 1) != 0);
|
||||
set_local_wire(self.gate.wire_bit(i, copy), bit);
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -322,6 +340,7 @@ mod tests {
|
||||
/// Returns the local wires for a random access gate given the vectors, elements to compare,
|
||||
/// and indices.
|
||||
fn get_wires(
|
||||
bits: usize,
|
||||
lists: Vec<Vec<F>>,
|
||||
access_indices: Vec<usize>,
|
||||
claimed_elements: Vec<F>,
|
||||
@ -330,8 +349,7 @@ mod tests {
|
||||
let vec_size = lists[0].len();
|
||||
|
||||
let mut v = Vec::new();
|
||||
let mut equality_dummy_vals = Vec::new();
|
||||
let mut index_matches_vals = Vec::new();
|
||||
let mut bit_vals = Vec::new();
|
||||
for copy in 0..num_copies {
|
||||
let access_index = access_indices[copy];
|
||||
v.push(F::from_canonical_usize(access_index));
|
||||
@ -340,26 +358,17 @@ mod tests {
|
||||
v.push(lists[copy][j]);
|
||||
}
|
||||
|
||||
for i in 0..vec_size {
|
||||
if i == access_index {
|
||||
equality_dummy_vals.push(F::ONE);
|
||||
index_matches_vals.push(F::ONE);
|
||||
} else {
|
||||
equality_dummy_vals.push(
|
||||
(F::from_canonical_usize(i) - F::from_canonical_usize(access_index))
|
||||
.inverse(),
|
||||
);
|
||||
index_matches_vals.push(F::ZERO);
|
||||
}
|
||||
for i in 0..bits {
|
||||
bit_vals.push(F::from_bool(((access_index >> i) & 1) != 0));
|
||||
}
|
||||
}
|
||||
v.extend(equality_dummy_vals);
|
||||
v.extend(index_matches_vals);
|
||||
v.extend(bit_vals);
|
||||
|
||||
v.iter().map(|&x| x.into()).collect::<Vec<_>>()
|
||||
v.iter().map(|&x| x.into()).collect()
|
||||
}
|
||||
|
||||
let vec_size = 3;
|
||||
let bits = 3;
|
||||
let vec_size = 1 << bits;
|
||||
let num_copies = 4;
|
||||
let lists = (0..num_copies)
|
||||
.map(|_| F::rand_vec(vec_size))
|
||||
@ -368,7 +377,7 @@ mod tests {
|
||||
.map(|_| thread_rng().gen_range(0..vec_size))
|
||||
.collect::<Vec<_>>();
|
||||
let gate = RandomAccessGate::<F, D> {
|
||||
vec_size,
|
||||
bits,
|
||||
num_copies,
|
||||
_phantom: PhantomData,
|
||||
};
|
||||
@ -380,13 +389,18 @@ mod tests {
|
||||
.collect();
|
||||
let good_vars = EvaluationVars {
|
||||
local_constants: &[],
|
||||
local_wires: &get_wires(lists.clone(), access_indices.clone(), good_claimed_elements),
|
||||
local_wires: &get_wires(
|
||||
bits,
|
||||
lists.clone(),
|
||||
access_indices.clone(),
|
||||
good_claimed_elements,
|
||||
),
|
||||
public_inputs_hash: &HashOut::rand(),
|
||||
};
|
||||
let bad_claimed_elements = F::rand_vec(4);
|
||||
let bad_vars = EvaluationVars {
|
||||
local_constants: &[],
|
||||
local_wires: &get_wires(lists, access_indices, bad_claimed_elements),
|
||||
local_wires: &get_wires(bits, lists, access_indices, bad_claimed_elements),
|
||||
public_inputs_hash: &HashOut::rand(),
|
||||
};
|
||||
|
||||
|
||||
222
src/gates/reducing_extension.rs
Normal file
222
src/gates/reducing_extension.rs
Normal file
@ -0,0 +1,222 @@
|
||||
use std::ops::Range;
|
||||
|
||||
use crate::field::extension_field::target::ExtensionTarget;
|
||||
use crate::field::extension_field::Extendable;
|
||||
use crate::field::extension_field::FieldExtension;
|
||||
use crate::field::field_types::RichField;
|
||||
use crate::gates::gate::Gate;
|
||||
use crate::iop::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator};
|
||||
use crate::iop::target::Target;
|
||||
use crate::iop::witness::{PartitionWitness, Witness};
|
||||
use crate::plonk::circuit_builder::CircuitBuilder;
|
||||
use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase};
|
||||
|
||||
/// Computes `sum alpha^i c_i` for a vector `c_i` of `num_coeffs` elements of the extension field.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ReducingExtensionGate<const D: usize> {
|
||||
pub num_coeffs: usize,
|
||||
}
|
||||
|
||||
impl<const D: usize> ReducingExtensionGate<D> {
|
||||
pub fn new(num_coeffs: usize) -> Self {
|
||||
Self { num_coeffs }
|
||||
}
|
||||
|
||||
pub fn max_coeffs_len(num_wires: usize, num_routed_wires: usize) -> usize {
|
||||
// `3*D` routed wires are used for the output, alpha and old accumulator.
|
||||
// Need `num_coeffs*D` routed wires for coeffs, and `(num_coeffs-1)*D` wires for accumulators.
|
||||
((num_routed_wires - 3 * D) / D).min((num_wires - 2 * D) / (D * 2))
|
||||
}
|
||||
|
||||
pub fn wires_output() -> Range<usize> {
|
||||
0..D
|
||||
}
|
||||
pub fn wires_alpha() -> Range<usize> {
|
||||
D..2 * D
|
||||
}
|
||||
pub fn wires_old_acc() -> Range<usize> {
|
||||
2 * D..3 * D
|
||||
}
|
||||
const START_COEFFS: usize = 3 * D;
|
||||
pub fn wires_coeff(i: usize) -> Range<usize> {
|
||||
Self::START_COEFFS + i * D..Self::START_COEFFS + (i + 1) * D
|
||||
}
|
||||
fn start_accs(&self) -> usize {
|
||||
Self::START_COEFFS + self.num_coeffs * D
|
||||
}
|
||||
fn wires_accs(&self, i: usize) -> Range<usize> {
|
||||
debug_assert!(i < self.num_coeffs);
|
||||
if i == self.num_coeffs - 1 {
|
||||
// The last accumulator is the output.
|
||||
return Self::wires_output();
|
||||
}
|
||||
self.start_accs() + D * i..self.start_accs() + D * (i + 1)
|
||||
}
|
||||
}
|
||||
|
||||
impl<F: RichField + Extendable<D>, const D: usize> Gate<F, D> for ReducingExtensionGate<D> {
|
||||
fn id(&self) -> String {
|
||||
format!("{:?}", self)
|
||||
}
|
||||
|
||||
fn eval_unfiltered(&self, vars: EvaluationVars<F, D>) -> Vec<F::Extension> {
|
||||
let alpha = vars.get_local_ext_algebra(Self::wires_alpha());
|
||||
let old_acc = vars.get_local_ext_algebra(Self::wires_old_acc());
|
||||
let coeffs = (0..self.num_coeffs)
|
||||
.map(|i| vars.get_local_ext_algebra(Self::wires_coeff(i)))
|
||||
.collect::<Vec<_>>();
|
||||
let accs = (0..self.num_coeffs)
|
||||
.map(|i| vars.get_local_ext_algebra(self.wires_accs(i)))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let mut constraints = Vec::with_capacity(<Self as Gate<F, D>>::num_constraints(self));
|
||||
let mut acc = old_acc;
|
||||
for i in 0..self.num_coeffs {
|
||||
constraints.push(acc * alpha + coeffs[i] - accs[i]);
|
||||
acc = accs[i];
|
||||
}
|
||||
|
||||
constraints
|
||||
.into_iter()
|
||||
.flat_map(|alg| alg.to_basefield_array())
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn eval_unfiltered_base(&self, vars: EvaluationVarsBase<F>) -> Vec<F> {
|
||||
let alpha = vars.get_local_ext(Self::wires_alpha());
|
||||
let old_acc = vars.get_local_ext(Self::wires_old_acc());
|
||||
let coeffs = (0..self.num_coeffs)
|
||||
.map(|i| vars.get_local_ext(Self::wires_coeff(i)))
|
||||
.collect::<Vec<_>>();
|
||||
let accs = (0..self.num_coeffs)
|
||||
.map(|i| vars.get_local_ext(self.wires_accs(i)))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let mut constraints = Vec::with_capacity(<Self as Gate<F, D>>::num_constraints(self));
|
||||
let mut acc = old_acc;
|
||||
for i in 0..self.num_coeffs {
|
||||
constraints.extend((acc * alpha + coeffs[i] - accs[i]).to_basefield_array());
|
||||
acc = accs[i];
|
||||
}
|
||||
|
||||
constraints
|
||||
}
|
||||
|
||||
fn eval_unfiltered_recursively(
|
||||
&self,
|
||||
builder: &mut CircuitBuilder<F, D>,
|
||||
vars: EvaluationTargets<D>,
|
||||
) -> Vec<ExtensionTarget<D>> {
|
||||
let alpha = vars.get_local_ext_algebra(Self::wires_alpha());
|
||||
let old_acc = vars.get_local_ext_algebra(Self::wires_old_acc());
|
||||
let coeffs = (0..self.num_coeffs)
|
||||
.map(|i| vars.get_local_ext_algebra(Self::wires_coeff(i)))
|
||||
.collect::<Vec<_>>();
|
||||
let accs = (0..self.num_coeffs)
|
||||
.map(|i| vars.get_local_ext_algebra(self.wires_accs(i)))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let mut constraints = Vec::with_capacity(<Self as Gate<F, D>>::num_constraints(self));
|
||||
let mut acc = old_acc;
|
||||
for i in 0..self.num_coeffs {
|
||||
let coeff = coeffs[i];
|
||||
let mut tmp = builder.mul_add_ext_algebra(acc, alpha, coeff);
|
||||
tmp = builder.sub_ext_algebra(tmp, accs[i]);
|
||||
constraints.push(tmp);
|
||||
acc = accs[i];
|
||||
}
|
||||
|
||||
constraints
|
||||
.into_iter()
|
||||
.flat_map(|alg| alg.to_ext_target_array())
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn generators(
|
||||
&self,
|
||||
gate_index: usize,
|
||||
_local_constants: &[F],
|
||||
) -> Vec<Box<dyn WitnessGenerator<F>>> {
|
||||
vec![Box::new(
|
||||
ReducingGenerator {
|
||||
gate_index,
|
||||
gate: self.clone(),
|
||||
}
|
||||
.adapter(),
|
||||
)]
|
||||
}
|
||||
|
||||
fn num_wires(&self) -> usize {
|
||||
2 * D + 2 * D * self.num_coeffs
|
||||
}
|
||||
|
||||
fn num_constants(&self) -> usize {
|
||||
0
|
||||
}
|
||||
|
||||
fn degree(&self) -> usize {
|
||||
2
|
||||
}
|
||||
|
||||
fn num_constraints(&self) -> usize {
|
||||
D * self.num_coeffs
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct ReducingGenerator<const D: usize> {
|
||||
gate_index: usize,
|
||||
gate: ReducingExtensionGate<D>,
|
||||
}
|
||||
|
||||
impl<F: Extendable<D>, const D: usize> SimpleGenerator<F> for ReducingGenerator<D> {
|
||||
fn dependencies(&self) -> Vec<Target> {
|
||||
ReducingExtensionGate::<D>::wires_alpha()
|
||||
.chain(ReducingExtensionGate::<D>::wires_old_acc())
|
||||
.chain((0..self.gate.num_coeffs).flat_map(ReducingExtensionGate::<D>::wires_coeff))
|
||||
.map(|i| Target::wire(self.gate_index, i))
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn run_once(&self, witness: &PartitionWitness<F>, out_buffer: &mut GeneratedValues<F>) {
|
||||
let local_extension = |range: Range<usize>| -> F::Extension {
|
||||
let t = ExtensionTarget::from_range(self.gate_index, range);
|
||||
witness.get_extension_target(t)
|
||||
};
|
||||
|
||||
let alpha = local_extension(ReducingExtensionGate::<D>::wires_alpha());
|
||||
let old_acc = local_extension(ReducingExtensionGate::<D>::wires_old_acc());
|
||||
let coeffs = (0..self.gate.num_coeffs)
|
||||
.map(|i| local_extension(ReducingExtensionGate::<D>::wires_coeff(i)))
|
||||
.collect::<Vec<_>>();
|
||||
let accs = (0..self.gate.num_coeffs)
|
||||
.map(|i| ExtensionTarget::from_range(self.gate_index, self.gate.wires_accs(i)))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let mut acc = old_acc;
|
||||
for i in 0..self.gate.num_coeffs {
|
||||
let computed_acc = acc * alpha + coeffs[i];
|
||||
out_buffer.set_extension_target(accs[i], computed_acc);
|
||||
acc = computed_acc;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use anyhow::Result;
|
||||
|
||||
use crate::field::goldilocks_field::GoldilocksField;
|
||||
use crate::gates::gate_testing::{test_eval_fns, test_low_degree};
|
||||
use crate::gates::reducing_extension::ReducingExtensionGate;
|
||||
|
||||
#[test]
|
||||
fn low_degree() {
|
||||
test_low_degree::<GoldilocksField, _, 4>(ReducingExtensionGate::new(22));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn eval_fns() -> Result<()> {
|
||||
test_eval_fns::<GoldilocksField, _, 4>(ReducingExtensionGate::new(22))
|
||||
}
|
||||
}
|
||||
423
src/gates/subtraction_u32.rs
Normal file
423
src/gates/subtraction_u32.rs
Normal file
@ -0,0 +1,423 @@
|
||||
use std::marker::PhantomData;
|
||||
|
||||
use crate::field::extension_field::target::ExtensionTarget;
|
||||
use crate::field::extension_field::Extendable;
|
||||
use crate::field::field_types::{Field, RichField};
|
||||
use crate::gates::gate::Gate;
|
||||
use crate::iop::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator};
|
||||
use crate::iop::target::Target;
|
||||
use crate::iop::wire::Wire;
|
||||
use crate::iop::witness::{PartitionWitness, Witness};
|
||||
use crate::plonk::circuit_builder::CircuitBuilder;
|
||||
use crate::plonk::circuit_data::CircuitConfig;
|
||||
use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase};
|
||||
|
||||
/// A gate to perform a subtraction on 32-bit limbs: given `x`, `y`, and `borrow`, it returns
|
||||
/// the result `x - y - borrow` and, if this underflows, a new `borrow`. Inputs are not range-checked.
|
||||
#[derive(Copy, Clone, Debug)]
|
||||
pub struct U32SubtractionGate<F: RichField + Extendable<D>, const D: usize> {
|
||||
pub num_ops: usize,
|
||||
_phantom: PhantomData<F>,
|
||||
}
|
||||
|
||||
impl<F: RichField + Extendable<D>, const D: usize> U32SubtractionGate<F, D> {
|
||||
pub fn new_from_config(config: &CircuitConfig) -> Self {
|
||||
Self {
|
||||
num_ops: Self::num_ops(config),
|
||||
_phantom: PhantomData,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn num_ops(config: &CircuitConfig) -> usize {
|
||||
let wires_per_op = 5 + Self::num_limbs();
|
||||
let routed_wires_per_op = 5;
|
||||
(config.num_wires / wires_per_op).min(config.num_routed_wires / routed_wires_per_op)
|
||||
}
|
||||
|
||||
pub fn wire_ith_input_x(&self, i: usize) -> usize {
|
||||
debug_assert!(i < self.num_ops);
|
||||
5 * i
|
||||
}
|
||||
pub fn wire_ith_input_y(&self, i: usize) -> usize {
|
||||
debug_assert!(i < self.num_ops);
|
||||
5 * i + 1
|
||||
}
|
||||
pub fn wire_ith_input_borrow(&self, i: usize) -> usize {
|
||||
debug_assert!(i < self.num_ops);
|
||||
5 * i + 2
|
||||
}
|
||||
|
||||
pub fn wire_ith_output_result(&self, i: usize) -> usize {
|
||||
debug_assert!(i < self.num_ops);
|
||||
5 * i + 3
|
||||
}
|
||||
pub fn wire_ith_output_borrow(&self, i: usize) -> usize {
|
||||
debug_assert!(i < self.num_ops);
|
||||
5 * i + 4
|
||||
}
|
||||
|
||||
pub fn limb_bits() -> usize {
|
||||
2
|
||||
}
|
||||
// We have limbs for the 32 bits of `output_result`.
|
||||
pub fn num_limbs() -> usize {
|
||||
32 / Self::limb_bits()
|
||||
}
|
||||
|
||||
pub fn wire_ith_output_jth_limb(&self, i: usize, j: usize) -> usize {
|
||||
debug_assert!(i < self.num_ops);
|
||||
debug_assert!(j < Self::num_limbs());
|
||||
5 * self.num_ops + Self::num_limbs() * i + j
|
||||
}
|
||||
}
|
||||
|
||||
impl<F: RichField + Extendable<D>, const D: usize> Gate<F, D> for U32SubtractionGate<F, D> {
|
||||
fn id(&self) -> String {
|
||||
format!("{:?}", self)
|
||||
}
|
||||
|
||||
fn eval_unfiltered(&self, vars: EvaluationVars<F, D>) -> Vec<F::Extension> {
|
||||
let mut constraints = Vec::with_capacity(self.num_constraints());
|
||||
for i in 0..self.num_ops {
|
||||
let input_x = vars.local_wires[self.wire_ith_input_x(i)];
|
||||
let input_y = vars.local_wires[self.wire_ith_input_y(i)];
|
||||
let input_borrow = vars.local_wires[self.wire_ith_input_borrow(i)];
|
||||
|
||||
let result_initial = input_x - input_y - input_borrow;
|
||||
let base = F::Extension::from_canonical_u64(1 << 32u64);
|
||||
|
||||
let output_result = vars.local_wires[self.wire_ith_output_result(i)];
|
||||
let output_borrow = vars.local_wires[self.wire_ith_output_borrow(i)];
|
||||
|
||||
constraints.push(output_result - (result_initial + base * output_borrow));
|
||||
|
||||
// Range-check output_result to be at most 32 bits.
|
||||
let mut combined_limbs = F::Extension::ZERO;
|
||||
let limb_base = F::Extension::from_canonical_u64(1u64 << Self::limb_bits());
|
||||
for j in (0..Self::num_limbs()).rev() {
|
||||
let this_limb = vars.local_wires[self.wire_ith_output_jth_limb(i, j)];
|
||||
let max_limb = 1 << Self::limb_bits();
|
||||
let product = (0..max_limb)
|
||||
.map(|x| this_limb - F::Extension::from_canonical_usize(x))
|
||||
.product();
|
||||
constraints.push(product);
|
||||
|
||||
combined_limbs = limb_base * combined_limbs + this_limb;
|
||||
}
|
||||
constraints.push(combined_limbs - output_result);
|
||||
|
||||
// Range-check output_borrow to be one bit.
|
||||
constraints.push(output_borrow * (F::Extension::ONE - output_borrow));
|
||||
}
|
||||
|
||||
constraints
|
||||
}
|
||||
|
||||
fn eval_unfiltered_base(&self, vars: EvaluationVarsBase<F>) -> Vec<F> {
|
||||
let mut constraints = Vec::with_capacity(self.num_constraints());
|
||||
for i in 0..self.num_ops {
|
||||
let input_x = vars.local_wires[self.wire_ith_input_x(i)];
|
||||
let input_y = vars.local_wires[self.wire_ith_input_y(i)];
|
||||
let input_borrow = vars.local_wires[self.wire_ith_input_borrow(i)];
|
||||
|
||||
let result_initial = input_x - input_y - input_borrow;
|
||||
let base = F::from_canonical_u64(1 << 32u64);
|
||||
|
||||
let output_result = vars.local_wires[self.wire_ith_output_result(i)];
|
||||
let output_borrow = vars.local_wires[self.wire_ith_output_borrow(i)];
|
||||
|
||||
constraints.push(output_result - (result_initial + base * output_borrow));
|
||||
|
||||
// Range-check output_result to be at most 32 bits.
|
||||
let mut combined_limbs = F::ZERO;
|
||||
let limb_base = F::from_canonical_u64(1u64 << Self::limb_bits());
|
||||
for j in (0..Self::num_limbs()).rev() {
|
||||
let this_limb = vars.local_wires[self.wire_ith_output_jth_limb(i, j)];
|
||||
let max_limb = 1 << Self::limb_bits();
|
||||
let product = (0..max_limb)
|
||||
.map(|x| this_limb - F::from_canonical_usize(x))
|
||||
.product();
|
||||
constraints.push(product);
|
||||
|
||||
combined_limbs = limb_base * combined_limbs + this_limb;
|
||||
}
|
||||
constraints.push(combined_limbs - output_result);
|
||||
|
||||
// Range-check output_borrow to be one bit.
|
||||
constraints.push(output_borrow * (F::ONE - output_borrow));
|
||||
}
|
||||
|
||||
constraints
|
||||
}
|
||||
|
||||
fn eval_unfiltered_recursively(
|
||||
&self,
|
||||
builder: &mut CircuitBuilder<F, D>,
|
||||
vars: EvaluationTargets<D>,
|
||||
) -> Vec<ExtensionTarget<D>> {
|
||||
let mut constraints = Vec::with_capacity(self.num_constraints());
|
||||
for i in 0..self.num_ops {
|
||||
let input_x = vars.local_wires[self.wire_ith_input_x(i)];
|
||||
let input_y = vars.local_wires[self.wire_ith_input_y(i)];
|
||||
let input_borrow = vars.local_wires[self.wire_ith_input_borrow(i)];
|
||||
|
||||
let diff = builder.sub_extension(input_x, input_y);
|
||||
let result_initial = builder.sub_extension(diff, input_borrow);
|
||||
let base = builder.constant_extension(F::Extension::from_canonical_u64(1 << 32u64));
|
||||
|
||||
let output_result = vars.local_wires[self.wire_ith_output_result(i)];
|
||||
let output_borrow = vars.local_wires[self.wire_ith_output_borrow(i)];
|
||||
|
||||
let computed_output = builder.mul_add_extension(base, output_borrow, result_initial);
|
||||
constraints.push(builder.sub_extension(output_result, computed_output));
|
||||
|
||||
// Range-check output_result to be at most 32 bits.
|
||||
let mut combined_limbs = builder.zero_extension();
|
||||
let limb_base = builder
|
||||
.constant_extension(F::Extension::from_canonical_u64(1u64 << Self::limb_bits()));
|
||||
for j in (0..Self::num_limbs()).rev() {
|
||||
let this_limb = vars.local_wires[self.wire_ith_output_jth_limb(i, j)];
|
||||
let max_limb = 1 << Self::limb_bits();
|
||||
let mut product = builder.one_extension();
|
||||
for x in 0..max_limb {
|
||||
let x_target =
|
||||
builder.constant_extension(F::Extension::from_canonical_usize(x));
|
||||
let diff = builder.sub_extension(this_limb, x_target);
|
||||
product = builder.mul_extension(product, diff);
|
||||
}
|
||||
constraints.push(product);
|
||||
|
||||
combined_limbs = builder.mul_add_extension(limb_base, combined_limbs, this_limb);
|
||||
}
|
||||
constraints.push(builder.sub_extension(combined_limbs, output_result));
|
||||
|
||||
// Range-check output_borrow to be one bit.
|
||||
let one = builder.one_extension();
|
||||
let not_borrow = builder.sub_extension(one, output_borrow);
|
||||
constraints.push(builder.mul_extension(output_borrow, not_borrow));
|
||||
}
|
||||
|
||||
constraints
|
||||
}
|
||||
|
||||
fn generators(
|
||||
&self,
|
||||
gate_index: usize,
|
||||
_local_constants: &[F],
|
||||
) -> Vec<Box<dyn WitnessGenerator<F>>> {
|
||||
(0..self.num_ops)
|
||||
.map(|i| {
|
||||
let g: Box<dyn WitnessGenerator<F>> = Box::new(
|
||||
U32SubtractionGenerator {
|
||||
gate: *self,
|
||||
gate_index,
|
||||
i,
|
||||
_phantom: PhantomData,
|
||||
}
|
||||
.adapter(),
|
||||
);
|
||||
g
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn num_wires(&self) -> usize {
|
||||
self.num_ops * (5 + Self::num_limbs())
|
||||
}
|
||||
|
||||
fn num_constants(&self) -> usize {
|
||||
0
|
||||
}
|
||||
|
||||
fn degree(&self) -> usize {
|
||||
1 << Self::limb_bits()
|
||||
}
|
||||
|
||||
fn num_constraints(&self) -> usize {
|
||||
self.num_ops * (3 + Self::num_limbs())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
struct U32SubtractionGenerator<F: RichField + Extendable<D>, const D: usize> {
|
||||
gate: U32SubtractionGate<F, D>,
|
||||
gate_index: usize,
|
||||
i: usize,
|
||||
_phantom: PhantomData<F>,
|
||||
}
|
||||
|
||||
impl<F: RichField + Extendable<D>, const D: usize> SimpleGenerator<F>
|
||||
for U32SubtractionGenerator<F, D>
|
||||
{
|
||||
fn dependencies(&self) -> Vec<Target> {
|
||||
let local_target = |input| Target::wire(self.gate_index, input);
|
||||
|
||||
vec![
|
||||
local_target(self.gate.wire_ith_input_x(self.i)),
|
||||
local_target(self.gate.wire_ith_input_y(self.i)),
|
||||
local_target(self.gate.wire_ith_input_borrow(self.i)),
|
||||
]
|
||||
}
|
||||
|
||||
fn run_once(&self, witness: &PartitionWitness<F>, out_buffer: &mut GeneratedValues<F>) {
|
||||
let local_wire = |input| Wire {
|
||||
gate: self.gate_index,
|
||||
input,
|
||||
};
|
||||
|
||||
let get_local_wire = |input| witness.get_wire(local_wire(input));
|
||||
|
||||
let input_x = get_local_wire(self.gate.wire_ith_input_x(self.i));
|
||||
let input_y = get_local_wire(self.gate.wire_ith_input_y(self.i));
|
||||
let input_borrow = get_local_wire(self.gate.wire_ith_input_borrow(self.i));
|
||||
|
||||
let result_initial = input_x - input_y - input_borrow;
|
||||
let result_initial_u64 = result_initial.to_canonical_u64();
|
||||
let output_borrow = if result_initial_u64 > 1 << 32u64 {
|
||||
F::ONE
|
||||
} else {
|
||||
F::ZERO
|
||||
};
|
||||
|
||||
let base = F::from_canonical_u64(1 << 32u64);
|
||||
let output_result = result_initial + base * output_borrow;
|
||||
|
||||
let output_result_wire = local_wire(self.gate.wire_ith_output_result(self.i));
|
||||
let output_borrow_wire = local_wire(self.gate.wire_ith_output_borrow(self.i));
|
||||
|
||||
out_buffer.set_wire(output_result_wire, output_result);
|
||||
out_buffer.set_wire(output_borrow_wire, output_borrow);
|
||||
|
||||
let output_result_u64 = output_result.to_canonical_u64();
|
||||
|
||||
let num_limbs = U32SubtractionGate::<F, D>::num_limbs();
|
||||
let limb_base = 1 << U32SubtractionGate::<F, D>::limb_bits();
|
||||
let output_limbs: Vec<_> = (0..num_limbs)
|
||||
.scan(output_result_u64, |acc, _| {
|
||||
let tmp = *acc % limb_base;
|
||||
*acc /= limb_base;
|
||||
Some(F::from_canonical_u64(tmp))
|
||||
})
|
||||
.collect();
|
||||
|
||||
for j in 0..num_limbs {
|
||||
let wire = local_wire(self.gate.wire_ith_output_jth_limb(self.i, j));
|
||||
out_buffer.set_wire(wire, output_limbs[j]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::marker::PhantomData;
|
||||
|
||||
use anyhow::Result;
|
||||
use rand::Rng;
|
||||
|
||||
use crate::field::extension_field::quartic::QuarticExtension;
|
||||
use crate::field::field_types::{Field, PrimeField};
|
||||
use crate::field::goldilocks_field::GoldilocksField;
|
||||
use crate::gates::gate::Gate;
|
||||
use crate::gates::gate_testing::{test_eval_fns, test_low_degree};
|
||||
use crate::gates::subtraction_u32::U32SubtractionGate;
|
||||
use crate::hash::hash_types::HashOut;
|
||||
use crate::plonk::vars::EvaluationVars;
|
||||
|
||||
#[test]
|
||||
fn low_degree() {
|
||||
test_low_degree::<GoldilocksField, _, 4>(U32SubtractionGate::<GoldilocksField, 4> {
|
||||
num_ops: 3,
|
||||
_phantom: PhantomData,
|
||||
})
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn eval_fns() -> Result<()> {
|
||||
test_eval_fns::<GoldilocksField, _, 4>(U32SubtractionGate::<GoldilocksField, 4> {
|
||||
num_ops: 3,
|
||||
_phantom: PhantomData,
|
||||
})
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_gate_constraint() {
|
||||
type F = GoldilocksField;
|
||||
type FF = QuarticExtension<GoldilocksField>;
|
||||
const D: usize = 4;
|
||||
const NUM_U32_SUBTRACTION_OPS: usize = 3;
|
||||
|
||||
fn get_wires(inputs_x: Vec<u64>, inputs_y: Vec<u64>, borrows: Vec<u64>) -> Vec<FF> {
|
||||
let mut v0 = Vec::new();
|
||||
let mut v1 = Vec::new();
|
||||
|
||||
let limb_bits = U32SubtractionGate::<F, D>::limb_bits();
|
||||
let num_limbs = U32SubtractionGate::<F, D>::num_limbs();
|
||||
let limb_base = 1 << limb_bits;
|
||||
for c in 0..NUM_U32_SUBTRACTION_OPS {
|
||||
let input_x = F::from_canonical_u64(inputs_x[c]);
|
||||
let input_y = F::from_canonical_u64(inputs_y[c]);
|
||||
let input_borrow = F::from_canonical_u64(borrows[c]);
|
||||
|
||||
let result_initial = input_x - input_y - input_borrow;
|
||||
let result_initial_u64 = result_initial.to_canonical_u64();
|
||||
let output_borrow = if result_initial_u64 > 1 << 32u64 {
|
||||
F::ONE
|
||||
} else {
|
||||
F::ZERO
|
||||
};
|
||||
|
||||
let base = F::from_canonical_u64(1 << 32u64);
|
||||
let output_result = result_initial + base * output_borrow;
|
||||
|
||||
let output_result_u64 = output_result.to_canonical_u64();
|
||||
|
||||
let mut output_limbs: Vec<_> = (0..num_limbs)
|
||||
.scan(output_result_u64, |acc, _| {
|
||||
let tmp = *acc % limb_base;
|
||||
*acc /= limb_base;
|
||||
Some(F::from_canonical_u64(tmp))
|
||||
})
|
||||
.collect();
|
||||
|
||||
v0.push(input_x);
|
||||
v0.push(input_y);
|
||||
v0.push(input_borrow);
|
||||
v0.push(output_result);
|
||||
v0.push(output_borrow);
|
||||
v1.append(&mut output_limbs);
|
||||
}
|
||||
|
||||
v0.iter()
|
||||
.chain(v1.iter())
|
||||
.map(|&x| x.into())
|
||||
.collect::<Vec<_>>()
|
||||
}
|
||||
|
||||
let mut rng = rand::thread_rng();
|
||||
let inputs_x = (0..NUM_U32_SUBTRACTION_OPS)
|
||||
.map(|_| rng.gen::<u32>() as u64)
|
||||
.collect();
|
||||
let inputs_y = (0..NUM_U32_SUBTRACTION_OPS)
|
||||
.map(|_| rng.gen::<u32>() as u64)
|
||||
.collect();
|
||||
let borrows = (0..NUM_U32_SUBTRACTION_OPS)
|
||||
.map(|_| (rng.gen::<u32>() % 2) as u64)
|
||||
.collect();
|
||||
|
||||
let gate = U32SubtractionGate::<F, D> {
|
||||
num_ops: NUM_U32_SUBTRACTION_OPS,
|
||||
_phantom: PhantomData,
|
||||
};
|
||||
|
||||
let vars = EvaluationVars {
|
||||
local_constants: &[],
|
||||
local_wires: &get_wires(inputs_x, inputs_y, borrows),
|
||||
public_inputs_hash: &HashOut::rand(),
|
||||
};
|
||||
|
||||
assert!(
|
||||
gate.eval_unfiltered(vars).iter().all(|x| x.is_zero()),
|
||||
"Gate constraints are not satisfied."
|
||||
);
|
||||
}
|
||||
}
|
||||
@ -1,5 +1,7 @@
|
||||
#![allow(clippy::assertions_on_constants)]
|
||||
|
||||
use std::arch::aarch64::*;
|
||||
use std::convert::TryInto;
|
||||
use std::arch::asm;
|
||||
|
||||
use static_assertions::const_assert;
|
||||
use unroll::unroll_for_loops;
|
||||
@ -172,9 +174,7 @@ unsafe fn multiply(x: u64, y: u64) -> u64 {
|
||||
let xy_hi_lo_mul_epsilon = mul_epsilon(xy_hi);
|
||||
|
||||
// add_with_wraparound is safe, as xy_hi_lo_mul_epsilon <= 0xfffffffe00000001 <= ORDER.
|
||||
let res1 = add_with_wraparound(res0, xy_hi_lo_mul_epsilon);
|
||||
|
||||
res1
|
||||
add_with_wraparound(res0, xy_hi_lo_mul_epsilon)
|
||||
}
|
||||
|
||||
// ==================================== STANDALONE CONST LAYER =====================================
|
||||
@ -267,9 +267,7 @@ unsafe fn mds_reduce(
|
||||
// Multiply by EPSILON and accumulate.
|
||||
let res_unadj = vmlal_laneq_u32::<0>(res_lo, res_hi_hi, mds_consts0);
|
||||
let res_adj = vcgtq_u64(res_lo, res_unadj);
|
||||
let res = vsraq_n_u64::<32>(res_unadj, res_adj);
|
||||
|
||||
res
|
||||
vsraq_n_u64::<32>(res_unadj, res_adj)
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
@ -969,8 +967,7 @@ unsafe fn partial_round(
|
||||
#[inline(always)]
|
||||
unsafe fn full_round(state: [u64; 12], round_constants: &[u64; WIDTH]) -> [u64; 12] {
|
||||
let state = sbox_layer_full(state);
|
||||
let state = mds_const_layers_full(state, round_constants);
|
||||
state
|
||||
mds_const_layers_full(state, round_constants)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
|
||||
@ -1,5 +1,4 @@
|
||||
use core::arch::x86_64::*;
|
||||
use std::convert::TryInto;
|
||||
use std::mem::size_of;
|
||||
|
||||
use static_assertions::const_assert;
|
||||
|
||||
@ -1,5 +1,3 @@
|
||||
use std::convert::TryInto;
|
||||
|
||||
use rand::Rng;
|
||||
use serde::{Deserialize, Deserializer, Serialize, Serializer};
|
||||
|
||||
|
||||
@ -1,7 +1,5 @@
|
||||
//! Concrete instantiation of a hash function.
|
||||
|
||||
use std::convert::TryInto;
|
||||
|
||||
use crate::field::extension_field::Extendable;
|
||||
use crate::field::field_types::RichField;
|
||||
use crate::hash::hash_types::{HashOut, HashOutTarget};
|
||||
@ -131,10 +129,8 @@ pub fn hash_n_to_m<F: RichField, P: PlonkyPermutation<F>>(
|
||||
|
||||
// Absorb all input chunks.
|
||||
for input_chunk in inputs.chunks(SPONGE_RATE) {
|
||||
for i in 0..input_chunk.len() {
|
||||
state[i] = input_chunk[i];
|
||||
}
|
||||
state = P::permute(state);
|
||||
state[..input_chunk.len()].copy_from_slice(input_chunk);
|
||||
state = permute(state);
|
||||
}
|
||||
|
||||
// Squeeze until we have the desired number of outputs.
|
||||
|
||||
@ -1,5 +1,3 @@
|
||||
use std::convert::TryInto;
|
||||
|
||||
use anyhow::{ensure, Result};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
@ -55,6 +53,7 @@ pub(crate) fn verify_merkle_proof<F: RichField, H: Hasher<F>>(
|
||||
impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
|
||||
/// Verifies that the given leaf data is present at the given index in the Merkle tree with the
|
||||
/// given cap. The index is given by it's little-endian bits.
|
||||
#[cfg(test)]
|
||||
pub(crate) fn verify_merkle_proof<H: AlgebraicHasher<F>>(
|
||||
&mut self,
|
||||
leaf_data: Vec<Target>,
|
||||
@ -94,7 +93,7 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
|
||||
proof: &MerkleProofTarget,
|
||||
) {
|
||||
let zero = self.zero();
|
||||
let mut state: HashOutTarget = self.hash_or_noop::<H>(leaf_data);
|
||||
let mut state:HashOutTarget = self.hash_or_noop(leaf_data);
|
||||
|
||||
for (&bit, &sibling) in leaf_index_bits.iter().zip(&proof.siblings) {
|
||||
let mut perm_inputs = [zero; SPONGE_WIDTH];
|
||||
@ -116,7 +115,7 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn assert_hashes_equal(&mut self, x: HashOutTarget, y: HashOutTarget) {
|
||||
pub fn connect_hashes(&mut self, x: HashOutTarget, y: HashOutTarget) {
|
||||
for i in 0..4 {
|
||||
self.connect(x.elements[i], y.elements[i]);
|
||||
}
|
||||
|
||||
@ -1,8 +1,6 @@
|
||||
//! Implementation of the Poseidon hash function, as described in
|
||||
//! https://eprint.iacr.org/2019/458.pdf
|
||||
|
||||
use std::convert::TryInto;
|
||||
|
||||
use unroll::unroll_for_loops;
|
||||
|
||||
use crate::field::extension_field::target::ExtensionTarget;
|
||||
@ -452,9 +450,10 @@ pub trait Poseidon: PrimeField {
|
||||
s0,
|
||||
);
|
||||
for i in 1..WIDTH {
|
||||
let t = <Self as Poseidon>::FAST_PARTIAL_ROUND_W_HATS[r][i - 1];
|
||||
let t = Self::from_canonical_u64(t);
|
||||
d = builder.mul_const_add_extension(t, state[i], d);
|
||||
let t = <Self as Poseidon<WIDTH>>::FAST_PARTIAL_ROUND_W_HATS[r][i - 1];
|
||||
let t = Self::Extension::from_canonical_u64(t);
|
||||
let t = builder.constant_extension(t);
|
||||
d = builder.mul_add_extension(t, state[i], d);
|
||||
}
|
||||
|
||||
let mut result = [builder.zero_extension(); WIDTH];
|
||||
@ -624,6 +623,7 @@ pub trait Poseidon: PrimeField {
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) mod test_helpers {
|
||||
use crate::field::field_types::Field;
|
||||
use crate::hash::hashing::SPONGE_WIDTH;
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
use crate::field::extension_field::target::ExtensionTarget;
|
||||
use std::convert::TryInto;
|
||||
use std::marker::PhantomData;
|
||||
|
||||
use crate::field::extension_field::target::ExtensionTarget;
|
||||
use crate::field::extension_field::{Extendable, FieldExtension};
|
||||
use crate::field::field_types::RichField;
|
||||
use crate::hash::hash_types::{HashOut, HashOutTarget, MerkleCapTarget};
|
||||
|
||||
@ -1,9 +1,15 @@
|
||||
use std::fmt::Debug;
|
||||
use std::marker::PhantomData;
|
||||
|
||||
use num::BigUint;
|
||||
|
||||
use crate::field::extension_field::target::ExtensionTarget;
|
||||
use crate::field::extension_field::{Extendable, FieldExtension};
|
||||
use crate::field::field_types::Field;
|
||||
use crate::field::field_types::{Field, RichField};
|
||||
use crate::gadgets::arithmetic_u32::U32Target;
|
||||
use crate::gadgets::biguint::BigUintTarget;
|
||||
use crate::gadgets::nonnative::NonNativeTarget;
|
||||
use crate::hash::hash_types::{HashOut, HashOutTarget};
|
||||
use crate::iop::target::Target;
|
||||
use crate::iop::wire::Wire;
|
||||
@ -89,7 +95,7 @@ pub(crate) fn generate_partial_witness<
|
||||
assert_eq!(
|
||||
remaining_generators, 0,
|
||||
"{} generators weren't run",
|
||||
remaining_generators
|
||||
remaining_generators,
|
||||
);
|
||||
|
||||
witness
|
||||
@ -156,6 +162,24 @@ impl<F: Field> GeneratedValues<F> {
|
||||
self.target_values.push((target, value))
|
||||
}
|
||||
|
||||
fn set_u32_target(&mut self, target: U32Target, value: u32) {
|
||||
self.set_target(target.0, F::from_canonical_u32(value))
|
||||
}
|
||||
|
||||
pub fn set_biguint_target(&mut self, target: BigUintTarget, value: BigUint) {
|
||||
let mut limbs = value.to_u32_digits();
|
||||
assert!(target.num_limbs() >= limbs.len());
|
||||
|
||||
limbs.resize(target.num_limbs(), 0);
|
||||
for i in 0..target.num_limbs() {
|
||||
self.set_u32_target(target.get_limb(i), limbs[i]);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn set_nonnative_target<FF: Field>(&mut self, target: NonNativeTarget<FF>, value: FF) {
|
||||
self.set_biguint_target(target.value, value.to_biguint())
|
||||
}
|
||||
|
||||
pub fn set_hash_target(&mut self, ht: HashOutTarget, value: HashOut<F>) {
|
||||
ht.elements
|
||||
.iter()
|
||||
|
||||
@ -41,6 +41,7 @@ impl Target {
|
||||
|
||||
/// A `Target` which has already been constrained such that it can only be 0 or 1.
|
||||
#[derive(Copy, Clone, Debug)]
|
||||
#[allow(clippy::manual_non_exhaustive)]
|
||||
pub struct BoolTarget {
|
||||
pub target: Target,
|
||||
/// This private field is here to force all instantiations to go through `new_unsafe`.
|
||||
|
||||
@ -1,9 +1,14 @@
|
||||
use std::collections::HashMap;
|
||||
use std::convert::TryInto;
|
||||
|
||||
use num::{BigUint, FromPrimitive, Zero};
|
||||
|
||||
use crate::field::extension_field::target::ExtensionTarget;
|
||||
use crate::field::extension_field::{Extendable, FieldExtension};
|
||||
use crate::field::field_types::{Field, RichField};
|
||||
use crate::field::field_types::Field;
|
||||
use crate::gadgets::arithmetic_u32::U32Target;
|
||||
use crate::gadgets::biguint::BigUintTarget;
|
||||
use crate::gadgets::nonnative::NonNativeTarget;
|
||||
use crate::hash::hash_types::HashOutTarget;
|
||||
use crate::hash::hash_types::{HashOut, MerkleCapTarget};
|
||||
use crate::hash::merkle_tree::MerkleCap;
|
||||
@ -54,6 +59,24 @@ pub trait Witness<F: Field> {
|
||||
panic!("not a bool")
|
||||
}
|
||||
|
||||
fn get_biguint_target(&self, target: BigUintTarget) -> BigUint {
|
||||
let mut result = BigUint::zero();
|
||||
|
||||
let limb_base = BigUint::from_u64(1 << 32u64).unwrap();
|
||||
for i in (0..target.num_limbs()).rev() {
|
||||
let limb = target.get_limb(i);
|
||||
result *= &limb_base;
|
||||
result += self.get_target(limb.0).to_biguint();
|
||||
}
|
||||
|
||||
result
|
||||
}
|
||||
|
||||
fn get_nonnative_target<FF: Field>(&self, target: NonNativeTarget<FF>) -> FF {
|
||||
let val = self.get_biguint_target(target.value);
|
||||
FF::from_biguint(val)
|
||||
}
|
||||
|
||||
fn get_hash_target(&self, ht: HashOutTarget) -> HashOut<F> {
|
||||
HashOut {
|
||||
elements: self.get_targets(&ht.elements).try_into().unwrap(),
|
||||
@ -122,6 +145,16 @@ pub trait Witness<F: Field> {
|
||||
self.set_target(target.target, F::from_bool(value))
|
||||
}
|
||||
|
||||
fn set_u32_target(&mut self, target: U32Target, value: u32) {
|
||||
self.set_target(target.0, F::from_canonical_u32(value))
|
||||
}
|
||||
|
||||
fn set_biguint_target(&mut self, target: &BigUintTarget, value: &BigUint) {
|
||||
for (<, &l) in target.limbs.iter().zip(&value.to_u32_digits()) {
|
||||
self.set_u32_target(lt, l);
|
||||
}
|
||||
}
|
||||
|
||||
fn set_wire(&mut self, wire: Wire, value: F) {
|
||||
self.set_target(Target::Wire(wire), value)
|
||||
}
|
||||
|
||||
18
src/lib.rs
18
src/lib.rs
@ -1,9 +1,15 @@
|
||||
#![feature(asm)]
|
||||
#![feature(destructuring_assignment)]
|
||||
#![allow(incomplete_features)]
|
||||
#![allow(const_evaluatable_unchecked)]
|
||||
#![allow(clippy::new_without_default)]
|
||||
#![allow(clippy::too_many_arguments)]
|
||||
#![allow(clippy::len_without_is_empty)]
|
||||
#![allow(clippy::needless_range_loop)]
|
||||
#![feature(asm_sym)]
|
||||
#![feature(generic_const_exprs)]
|
||||
#![feature(specialization)]
|
||||
#![feature(stdsimd)]
|
||||
|
||||
pub mod curve;
|
||||
pub mod field;
|
||||
pub mod fri;
|
||||
pub mod gadgets;
|
||||
@ -13,3 +19,11 @@ pub mod iop;
|
||||
pub mod plonk;
|
||||
pub mod polynomial;
|
||||
pub mod util;
|
||||
|
||||
// Set up Jemalloc
|
||||
#[cfg(not(target_env = "msvc"))]
|
||||
use jemallocator::Jemalloc;
|
||||
|
||||
#[cfg(not(target_env = "msvc"))]
|
||||
#[global_allocator]
|
||||
static GLOBAL: Jemalloc = Jemalloc;
|
||||
|
||||
@ -1,6 +1,5 @@
|
||||
use std::cmp::max;
|
||||
use std::collections::{BTreeMap, HashMap, HashSet};
|
||||
use std::convert::TryInto;
|
||||
use std::time::Instant;
|
||||
|
||||
use log::{debug, info, Level};
|
||||
@ -12,14 +11,20 @@ use crate::field::fft::fft_root_table;
|
||||
use crate::field::field_types::Field;
|
||||
use crate::fri::commitment::PolynomialBatchCommitment;
|
||||
use crate::fri::{FriConfig, FriParams};
|
||||
use crate::gadgets::arithmetic_extension::ArithmeticOperation;
|
||||
use crate::gates::arithmetic::ArithmeticExtensionGate;
|
||||
use crate::gadgets::arithmetic::BaseArithmeticOperation;
|
||||
use crate::gadgets::arithmetic_extension::ExtensionArithmeticOperation;
|
||||
use crate::gadgets::arithmetic_u32::U32Target;
|
||||
use crate::gates::arithmetic_base::ArithmeticGate;
|
||||
use crate::gates::arithmetic_extension::ArithmeticExtensionGate;
|
||||
use crate::gates::arithmetic_u32::U32ArithmeticGate;
|
||||
use crate::gates::constant::ConstantGate;
|
||||
use crate::gates::gate::{Gate, GateInstance, GateRef, PrefixedGate};
|
||||
use crate::gates::gate_tree::Tree;
|
||||
use crate::gates::multiplication_extension::MulExtensionGate;
|
||||
use crate::gates::noop::NoopGate;
|
||||
use crate::gates::public_input::PublicInputGate;
|
||||
use crate::gates::random_access::RandomAccessGate;
|
||||
use crate::gates::subtraction_u32::U32SubtractionGate;
|
||||
use crate::gates::switch::SwitchGate;
|
||||
use crate::hash::hash_types::{HashOutTarget, MerkleCapTarget};
|
||||
use crate::iop::generator::{
|
||||
@ -35,7 +40,7 @@ use crate::plonk::config::{GenericConfig, Hasher};
|
||||
use crate::plonk::copy_constraint::CopyConstraint;
|
||||
use crate::plonk::permutation_argument::Forest;
|
||||
use crate::plonk::plonk_common::PlonkPolynomials;
|
||||
use crate::polynomial::polynomial::PolynomialValues;
|
||||
use crate::polynomial::PolynomialValues;
|
||||
use crate::util::context_tree::ContextTree;
|
||||
use crate::util::marking::{Markable, MarkedTargets};
|
||||
use crate::util::partial_products::num_partial_products;
|
||||
@ -71,24 +76,13 @@ pub struct CircuitBuilder<F: Extendable<D>, const D: usize> {
|
||||
constants_to_targets: HashMap<F, Target>,
|
||||
targets_to_constants: HashMap<Target, F>,
|
||||
|
||||
/// Memoized results of `arithmetic` calls.
|
||||
pub(crate) base_arithmetic_results: HashMap<BaseArithmeticOperation<F>, Target>,
|
||||
|
||||
/// Memoized results of `arithmetic_extension` calls.
|
||||
pub(crate) arithmetic_results: HashMap<ArithmeticOperation<F, D>, ExtensionTarget<D>>,
|
||||
pub(crate) arithmetic_results: HashMap<ExtensionArithmeticOperation<F, D>, ExtensionTarget<D>>,
|
||||
|
||||
/// A map `(c0, c1) -> (g, i)` from constants `(c0,c1)` to an available arithmetic gate using
|
||||
/// these constants with gate index `g` and already using `i` arithmetic operations.
|
||||
pub(crate) free_arithmetic: HashMap<(F, F), (usize, usize)>,
|
||||
|
||||
/// A map `(c0, c1) -> (g, i)` from constants `vec_size` to an available arithmetic gate using
|
||||
/// these constants with gate index `g` and already using `i` random accesses.
|
||||
pub(crate) free_random_access: HashMap<usize, (usize, usize)>,
|
||||
|
||||
// `current_switch_gates[chunk_size - 1]` contains None if we have no switch gates with the value
|
||||
// chunk_size, and contains `(g, i, c)`, if the gate `g`, at index `i`, already contains `c` copies
|
||||
// of switches
|
||||
pub(crate) current_switch_gates: Vec<Option<(SwitchGate<F, D>, usize, usize)>>,
|
||||
|
||||
/// An available `ConstantGate` instance, if any.
|
||||
free_constant: Option<(usize, usize)>,
|
||||
batched_gates: BatchedGates<F, D>,
|
||||
}
|
||||
|
||||
impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
|
||||
@ -104,12 +98,10 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
|
||||
marked_targets: Vec::new(),
|
||||
generators: Vec::new(),
|
||||
constants_to_targets: HashMap::new(),
|
||||
base_arithmetic_results: HashMap::new(),
|
||||
arithmetic_results: HashMap::new(),
|
||||
targets_to_constants: HashMap::new(),
|
||||
free_arithmetic: HashMap::new(),
|
||||
free_random_access: HashMap::new(),
|
||||
current_switch_gates: Vec::new(),
|
||||
free_constant: None,
|
||||
batched_gates: BatchedGates::new(),
|
||||
};
|
||||
builder.check_config();
|
||||
builder
|
||||
@ -216,6 +208,7 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
|
||||
gate_ref,
|
||||
constants,
|
||||
});
|
||||
|
||||
index
|
||||
}
|
||||
|
||||
@ -260,6 +253,11 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
|
||||
self.connect(x, zero);
|
||||
}
|
||||
|
||||
pub fn assert_one(&mut self, x: Target) {
|
||||
let one = self.one();
|
||||
self.connect(x, one);
|
||||
}
|
||||
|
||||
pub fn add_generators(&mut self, generators: Vec<Box<dyn WitnessGenerator<F>>>) {
|
||||
self.generators.extend(generators);
|
||||
}
|
||||
@ -313,26 +311,6 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
|
||||
target
|
||||
}
|
||||
|
||||
/// Returns the gate index and copy index of a free `ConstantGate` slot, potentially adding a
|
||||
/// new `ConstantGate` if needed.
|
||||
fn constant_gate_instance(&mut self) -> (usize, usize) {
|
||||
if self.free_constant.is_none() {
|
||||
let num_consts = self.config.constant_gate_size;
|
||||
// We will fill this `ConstantGate` with zero constants initially.
|
||||
// These will be overwritten by `constant` as the gate instances are filled.
|
||||
let gate = self.add_gate(ConstantGate { num_consts }, vec![F::ZERO; num_consts]);
|
||||
self.free_constant = Some((gate, 0));
|
||||
}
|
||||
|
||||
let (gate, instance) = self.free_constant.unwrap();
|
||||
if instance + 1 < self.config.constant_gate_size {
|
||||
self.free_constant = Some((gate, instance + 1));
|
||||
} else {
|
||||
self.free_constant = None;
|
||||
}
|
||||
(gate, instance)
|
||||
}
|
||||
|
||||
pub fn constants(&mut self, constants: &[F]) -> Vec<Target> {
|
||||
constants.iter().map(|&c| self.constant(c)).collect()
|
||||
}
|
||||
@ -345,6 +323,11 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns a U32Target for the value `c`, which is assumed to be at most 32 bits.
|
||||
pub fn constant_u32(&mut self, c: u32) -> U32Target {
|
||||
U32Target(self.constant(F::from_canonical_u32(c)))
|
||||
}
|
||||
|
||||
/// If the given target is a constant (i.e. it was created by the `constant(F)` method), returns
|
||||
/// its constant value. Otherwise, returns `None`.
|
||||
pub fn target_as_constant(&self, target: Target) -> Option<F> {
|
||||
@ -396,6 +379,20 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
|
||||
}
|
||||
}
|
||||
|
||||
/// The number of (base field) `arithmetic` operations that can be performed in a single gate.
|
||||
pub(crate) fn num_base_arithmetic_ops_per_gate(&self) -> usize {
|
||||
if self.config.use_base_arithmetic_gate {
|
||||
ArithmeticGate::new_from_config(&self.config).num_ops
|
||||
} else {
|
||||
self.num_ext_arithmetic_ops_per_gate()
|
||||
}
|
||||
}
|
||||
|
||||
/// The number of `arithmetic_extension` operations that can be performed in a single gate.
|
||||
pub(crate) fn num_ext_arithmetic_ops_per_gate(&self) -> usize {
|
||||
ArithmeticExtensionGate::<D>::new_from_config(&self.config).num_ops
|
||||
}
|
||||
|
||||
/// The number of polynomial values that will be revealed per opening, both for the "regular"
|
||||
/// polynomials and for the Z polynomials. Because calculating these values involves a recursive
|
||||
/// dependence (the amount of blinding depends on the degree, which depends on the blinding),
|
||||
@ -566,76 +563,6 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
|
||||
)
|
||||
}
|
||||
|
||||
/// Fill the remaining unused arithmetic operations with zeros, so that all
|
||||
/// `ArithmeticExtensionGenerator` are run.
|
||||
fn fill_arithmetic_gates(&mut self) {
|
||||
let zero = self.zero_extension();
|
||||
let remaining_arithmetic_gates = self.free_arithmetic.values().copied().collect::<Vec<_>>();
|
||||
for (gate, i) in remaining_arithmetic_gates {
|
||||
for j in i..ArithmeticExtensionGate::<D>::num_ops(&self.config) {
|
||||
let wires_multiplicand_0 = ExtensionTarget::from_range(
|
||||
gate,
|
||||
ArithmeticExtensionGate::<D>::wires_ith_multiplicand_0(j),
|
||||
);
|
||||
let wires_multiplicand_1 = ExtensionTarget::from_range(
|
||||
gate,
|
||||
ArithmeticExtensionGate::<D>::wires_ith_multiplicand_1(j),
|
||||
);
|
||||
let wires_addend = ExtensionTarget::from_range(
|
||||
gate,
|
||||
ArithmeticExtensionGate::<D>::wires_ith_addend(j),
|
||||
);
|
||||
|
||||
self.connect_extension(zero, wires_multiplicand_0);
|
||||
self.connect_extension(zero, wires_multiplicand_1);
|
||||
self.connect_extension(zero, wires_addend);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Fill the remaining unused random access operations with zeros, so that all
|
||||
/// `RandomAccessGenerator`s are run.
|
||||
fn fill_random_access_gates(&mut self) {
|
||||
let zero = self.zero();
|
||||
for (vec_size, (_, i)) in self.free_random_access.clone() {
|
||||
let max_copies = RandomAccessGate::<F, D>::max_num_copies(
|
||||
self.config.num_routed_wires,
|
||||
self.config.num_wires,
|
||||
vec_size,
|
||||
);
|
||||
for _ in i..max_copies {
|
||||
self.random_access(zero, zero, vec![zero; vec_size]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Fill the remaining unused switch gates with dummy values, so that all
|
||||
/// `SwitchGenerator` are run.
|
||||
fn fill_switch_gates(&mut self) {
|
||||
let zero = self.zero();
|
||||
|
||||
for chunk_size in 1..=self.current_switch_gates.len() {
|
||||
if let Some((gate, gate_index, mut copy)) =
|
||||
self.current_switch_gates[chunk_size - 1].clone()
|
||||
{
|
||||
while copy < gate.num_copies {
|
||||
for element in 0..chunk_size {
|
||||
let wire_first_input =
|
||||
Target::wire(gate_index, gate.wire_first_input(copy, element));
|
||||
let wire_second_input =
|
||||
Target::wire(gate_index, gate.wire_second_input(copy, element));
|
||||
let wire_switch_bool =
|
||||
Target::wire(gate_index, gate.wire_switch_bool(copy));
|
||||
self.connect(zero, wire_first_input);
|
||||
self.connect(zero, wire_second_input);
|
||||
self.connect(zero, wire_switch_bool);
|
||||
}
|
||||
copy += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn print_gate_counts(&self, min_delta: usize) {
|
||||
// Print gate counts for each context.
|
||||
self.context_log
|
||||
@ -659,9 +586,7 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
|
||||
let mut timing = TimingTree::new("preprocess", Level::Trace);
|
||||
let start = Instant::now();
|
||||
|
||||
self.fill_arithmetic_gates();
|
||||
self.fill_random_access_gates();
|
||||
self.fill_switch_gates();
|
||||
self.fill_batched_gates();
|
||||
|
||||
// Hash the public inputs, and route them to a `PublicInputGate` which will enforce that
|
||||
// those hash wires match the claimed public inputs.
|
||||
@ -698,7 +623,7 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
|
||||
..=1 << self.config.rate_bits)
|
||||
.min_by_key(|&q| num_partial_products(self.config.num_routed_wires, q).0 + q)
|
||||
.unwrap();
|
||||
info!("Quotient degree factor set to: {}.", quotient_degree_factor);
|
||||
debug!("Quotient degree factor set to: {}.", quotient_degree_factor);
|
||||
let prefixed_gates = PrefixedGate::from_tree(gate_tree);
|
||||
|
||||
let subgroup = F::two_adic_subgroup(degree_bits);
|
||||
@ -710,7 +635,7 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
|
||||
|
||||
// Precompute FFT roots.
|
||||
let max_fft_points =
|
||||
1 << degree_bits + max(self.config.rate_bits, log2_ceil(quotient_degree_factor));
|
||||
1 << (degree_bits + max(self.config.rate_bits, log2_ceil(quotient_degree_factor)));
|
||||
let fft_root_table = fft_root_table(max_fft_points);
|
||||
|
||||
let constants_sigmas_vecs = [constant_vecs, sigma_vecs.clone()].concat();
|
||||
@ -745,7 +670,7 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
|
||||
let watch_rep_index = forest.parents[watch_index];
|
||||
generator_indices_by_watches
|
||||
.entry(watch_rep_index)
|
||||
.or_insert(vec![])
|
||||
.or_insert_with(Vec::new)
|
||||
.push(i);
|
||||
}
|
||||
}
|
||||
@ -801,7 +726,7 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
|
||||
circuit_digest,
|
||||
};
|
||||
|
||||
info!("Building circuit took {}s", start.elapsed().as_secs_f32());
|
||||
debug!("Building circuit took {}s", start.elapsed().as_secs_f32());
|
||||
CircuitData {
|
||||
prover_only,
|
||||
verifier_only,
|
||||
@ -837,3 +762,386 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Various gate types can contain multiple copies in a single Gate. This helper struct lets a
|
||||
/// CircuitBuilder track such gates that are currently being "filled up."
|
||||
pub struct BatchedGates<F: RichField + Extendable<D>, const D: usize> {
|
||||
/// A map `(c0, c1) -> (g, i)` from constants `(c0,c1)` to an available arithmetic gate using
|
||||
/// these constants with gate index `g` and already using `i` arithmetic operations.
|
||||
pub(crate) free_arithmetic: HashMap<(F, F), (usize, usize)>,
|
||||
pub(crate) free_base_arithmetic: HashMap<(F, F), (usize, usize)>,
|
||||
|
||||
pub(crate) free_mul: HashMap<F, (usize, usize)>,
|
||||
|
||||
/// A map `b -> (g, i)` from `b` bits to an available random access gate of that size with gate
|
||||
/// index `g` and already using `i` random accesses.
|
||||
pub(crate) free_random_access: HashMap<usize, (usize, usize)>,
|
||||
|
||||
/// `current_switch_gates[chunk_size - 1]` contains None if we have no switch gates with the value
|
||||
/// chunk_size, and contains `(g, i, c)`, if the gate `g`, at index `i`, already contains `c` copies
|
||||
/// of switches
|
||||
pub(crate) current_switch_gates: Vec<Option<(SwitchGate<F, D>, usize, usize)>>,
|
||||
|
||||
/// The `U32ArithmeticGate` currently being filled (so new u32 arithmetic operations will be added to this gate before creating a new one)
|
||||
pub(crate) current_u32_arithmetic_gate: Option<(usize, usize)>,
|
||||
|
||||
/// The `U32SubtractionGate` currently being filled (so new u32 subtraction operations will be added to this gate before creating a new one)
|
||||
pub(crate) current_u32_subtraction_gate: Option<(usize, usize)>,
|
||||
|
||||
/// An available `ConstantGate` instance, if any.
|
||||
pub(crate) free_constant: Option<(usize, usize)>,
|
||||
}
|
||||
|
||||
impl<F: RichField + Extendable<D>, const D: usize> BatchedGates<F, D> {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
free_arithmetic: HashMap::new(),
|
||||
free_base_arithmetic: HashMap::new(),
|
||||
free_mul: HashMap::new(),
|
||||
free_random_access: HashMap::new(),
|
||||
current_switch_gates: Vec::new(),
|
||||
current_u32_arithmetic_gate: None,
|
||||
current_u32_subtraction_gate: None,
|
||||
free_constant: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
|
||||
/// Finds the last available arithmetic gate with the given constants or add one if there aren't any.
|
||||
/// Returns `(g,i)` such that there is an arithmetic gate with the given constants at index
|
||||
/// `g` and the gate's `i`-th operation is available.
|
||||
pub(crate) fn find_base_arithmetic_gate(&mut self, const_0: F, const_1: F) -> (usize, usize) {
|
||||
let (gate, i) = self
|
||||
.batched_gates
|
||||
.free_base_arithmetic
|
||||
.get(&(const_0, const_1))
|
||||
.copied()
|
||||
.unwrap_or_else(|| {
|
||||
let gate = self.add_gate(
|
||||
ArithmeticGate::new_from_config(&self.config),
|
||||
vec![const_0, const_1],
|
||||
);
|
||||
(gate, 0)
|
||||
});
|
||||
|
||||
// Update `free_arithmetic` with new values.
|
||||
if i < ArithmeticGate::num_ops(&self.config) - 1 {
|
||||
self.batched_gates
|
||||
.free_base_arithmetic
|
||||
.insert((const_0, const_1), (gate, i + 1));
|
||||
} else {
|
||||
self.batched_gates
|
||||
.free_base_arithmetic
|
||||
.remove(&(const_0, const_1));
|
||||
}
|
||||
|
||||
(gate, i)
|
||||
}
|
||||
|
||||
/// Finds the last available arithmetic gate with the given constants or add one if there aren't any.
|
||||
/// Returns `(g,i)` such that there is an arithmetic gate with the given constants at index
|
||||
/// `g` and the gate's `i`-th operation is available.
|
||||
pub(crate) fn find_arithmetic_gate(&mut self, const_0: F, const_1: F) -> (usize, usize) {
|
||||
let (gate, i) = self
|
||||
.batched_gates
|
||||
.free_arithmetic
|
||||
.get(&(const_0, const_1))
|
||||
.copied()
|
||||
.unwrap_or_else(|| {
|
||||
let gate = self.add_gate(
|
||||
ArithmeticExtensionGate::new_from_config(&self.config),
|
||||
vec![const_0, const_1],
|
||||
);
|
||||
(gate, 0)
|
||||
});
|
||||
|
||||
// Update `free_arithmetic` with new values.
|
||||
if i < ArithmeticExtensionGate::<D>::num_ops(&self.config) - 1 {
|
||||
self.batched_gates
|
||||
.free_arithmetic
|
||||
.insert((const_0, const_1), (gate, i + 1));
|
||||
} else {
|
||||
self.batched_gates
|
||||
.free_arithmetic
|
||||
.remove(&(const_0, const_1));
|
||||
}
|
||||
|
||||
(gate, i)
|
||||
}
|
||||
|
||||
/// Finds the last available arithmetic gate with the given constants or add one if there aren't any.
|
||||
/// Returns `(g,i)` such that there is an arithmetic gate with the given constants at index
|
||||
/// `g` and the gate's `i`-th operation is available.
|
||||
pub(crate) fn find_mul_gate(&mut self, const_0: F) -> (usize, usize) {
|
||||
let (gate, i) = self
|
||||
.batched_gates
|
||||
.free_mul
|
||||
.get(&const_0)
|
||||
.copied()
|
||||
.unwrap_or_else(|| {
|
||||
let gate = self.add_gate(
|
||||
MulExtensionGate::new_from_config(&self.config),
|
||||
vec![const_0],
|
||||
);
|
||||
(gate, 0)
|
||||
});
|
||||
|
||||
// Update `free_arithmetic` with new values.
|
||||
if i < MulExtensionGate::<D>::num_ops(&self.config) - 1 {
|
||||
self.batched_gates.free_mul.insert(const_0, (gate, i + 1));
|
||||
} else {
|
||||
self.batched_gates.free_mul.remove(&const_0);
|
||||
}
|
||||
|
||||
(gate, i)
|
||||
}
|
||||
|
||||
/// Finds the last available random access gate with the given `vec_size` or add one if there aren't any.
|
||||
/// Returns `(g,i)` such that there is a random access gate with the given `vec_size` at index
|
||||
/// `g` and the gate's `i`-th random access is available.
|
||||
pub(crate) fn find_random_access_gate(&mut self, bits: usize) -> (usize, usize) {
|
||||
let (gate, i) = self
|
||||
.batched_gates
|
||||
.free_random_access
|
||||
.get(&bits)
|
||||
.copied()
|
||||
.unwrap_or_else(|| {
|
||||
let gate = self.add_gate(
|
||||
RandomAccessGate::new_from_config(&self.config, bits),
|
||||
vec![],
|
||||
);
|
||||
(gate, 0)
|
||||
});
|
||||
|
||||
// Update `free_random_access` with new values.
|
||||
if i + 1 < RandomAccessGate::<F, D>::new_from_config(&self.config, bits).num_copies {
|
||||
self.batched_gates
|
||||
.free_random_access
|
||||
.insert(bits, (gate, i + 1));
|
||||
} else {
|
||||
self.batched_gates.free_random_access.remove(&bits);
|
||||
}
|
||||
|
||||
(gate, i)
|
||||
}
|
||||
|
||||
pub(crate) fn find_switch_gate(
|
||||
&mut self,
|
||||
chunk_size: usize,
|
||||
) -> (SwitchGate<F, D>, usize, usize) {
|
||||
if self.batched_gates.current_switch_gates.len() < chunk_size {
|
||||
self.batched_gates.current_switch_gates.extend(vec![
|
||||
None;
|
||||
chunk_size
|
||||
- self
|
||||
.batched_gates
|
||||
.current_switch_gates
|
||||
.len()
|
||||
]);
|
||||
}
|
||||
|
||||
let (gate, gate_index, next_copy) =
|
||||
match self.batched_gates.current_switch_gates[chunk_size - 1].clone() {
|
||||
None => {
|
||||
let gate = SwitchGate::<F, D>::new_from_config(&self.config, chunk_size);
|
||||
let gate_index = self.add_gate(gate.clone(), vec![]);
|
||||
(gate, gate_index, 0)
|
||||
}
|
||||
Some((gate, idx, next_copy)) => (gate, idx, next_copy),
|
||||
};
|
||||
|
||||
let num_copies = gate.num_copies;
|
||||
|
||||
if next_copy == num_copies - 1 {
|
||||
self.batched_gates.current_switch_gates[chunk_size - 1] = None;
|
||||
} else {
|
||||
self.batched_gates.current_switch_gates[chunk_size - 1] =
|
||||
Some((gate.clone(), gate_index, next_copy + 1));
|
||||
}
|
||||
|
||||
(gate, gate_index, next_copy)
|
||||
}
|
||||
|
||||
pub(crate) fn find_u32_arithmetic_gate(&mut self) -> (usize, usize) {
|
||||
let (gate_index, copy) = match self.batched_gates.current_u32_arithmetic_gate {
|
||||
None => {
|
||||
let gate = U32ArithmeticGate::new_from_config(&self.config);
|
||||
let gate_index = self.add_gate(gate, vec![]);
|
||||
(gate_index, 0)
|
||||
}
|
||||
Some((gate_index, copy)) => (gate_index, copy),
|
||||
};
|
||||
|
||||
if copy == U32ArithmeticGate::<F, D>::num_ops(&self.config) - 1 {
|
||||
self.batched_gates.current_u32_arithmetic_gate = None;
|
||||
} else {
|
||||
self.batched_gates.current_u32_arithmetic_gate = Some((gate_index, copy + 1));
|
||||
}
|
||||
|
||||
(gate_index, copy)
|
||||
}
|
||||
|
||||
pub(crate) fn find_u32_subtraction_gate(&mut self) -> (usize, usize) {
|
||||
let (gate_index, copy) = match self.batched_gates.current_u32_subtraction_gate {
|
||||
None => {
|
||||
let gate = U32SubtractionGate::new_from_config(&self.config);
|
||||
let gate_index = self.add_gate(gate, vec![]);
|
||||
(gate_index, 0)
|
||||
}
|
||||
Some((gate_index, copy)) => (gate_index, copy),
|
||||
};
|
||||
|
||||
if copy == U32SubtractionGate::<F, D>::num_ops(&self.config) - 1 {
|
||||
self.batched_gates.current_u32_subtraction_gate = None;
|
||||
} else {
|
||||
self.batched_gates.current_u32_subtraction_gate = Some((gate_index, copy + 1));
|
||||
}
|
||||
|
||||
(gate_index, copy)
|
||||
}
|
||||
|
||||
/// Returns the gate index and copy index of a free `ConstantGate` slot, potentially adding a
|
||||
/// new `ConstantGate` if needed.
|
||||
fn constant_gate_instance(&mut self) -> (usize, usize) {
|
||||
if self.batched_gates.free_constant.is_none() {
|
||||
let num_consts = self.config.constant_gate_size;
|
||||
// We will fill this `ConstantGate` with zero constants initially.
|
||||
// These will be overwritten by `constant` as the gate instances are filled.
|
||||
let gate = self.add_gate(ConstantGate { num_consts }, vec![F::ZERO; num_consts]);
|
||||
self.batched_gates.free_constant = Some((gate, 0));
|
||||
}
|
||||
|
||||
let (gate, instance) = self.batched_gates.free_constant.unwrap();
|
||||
if instance + 1 < self.config.constant_gate_size {
|
||||
self.batched_gates.free_constant = Some((gate, instance + 1));
|
||||
} else {
|
||||
self.batched_gates.free_constant = None;
|
||||
}
|
||||
(gate, instance)
|
||||
}
|
||||
|
||||
/// Fill the remaining unused arithmetic operations with zeros, so that all
|
||||
/// `ArithmeticGate` are run.
|
||||
fn fill_base_arithmetic_gates(&mut self) {
|
||||
let zero = self.zero();
|
||||
for ((c0, c1), (_gate, i)) in self.batched_gates.free_base_arithmetic.clone() {
|
||||
for _ in i..ArithmeticGate::num_ops(&self.config) {
|
||||
// If we directly wire in zero, an optimization will skip doing anything and return
|
||||
// zero. So we pass in a virtual target and connect it to zero afterward.
|
||||
let dummy = self.add_virtual_target();
|
||||
self.arithmetic(c0, c1, dummy, dummy, dummy);
|
||||
self.connect(dummy, zero);
|
||||
}
|
||||
}
|
||||
assert!(self.batched_gates.free_base_arithmetic.is_empty());
|
||||
}
|
||||
|
||||
/// Fill the remaining unused arithmetic operations with zeros, so that all
|
||||
/// `ArithmeticExtensionGenerator`s are run.
|
||||
fn fill_arithmetic_gates(&mut self) {
|
||||
let zero = self.zero_extension();
|
||||
for ((c0, c1), (_gate, i)) in self.batched_gates.free_arithmetic.clone() {
|
||||
for _ in i..ArithmeticExtensionGate::<D>::num_ops(&self.config) {
|
||||
// If we directly wire in zero, an optimization will skip doing anything and return
|
||||
// zero. So we pass in a virtual target and connect it to zero afterward.
|
||||
let dummy = self.add_virtual_extension_target();
|
||||
self.arithmetic_extension(c0, c1, dummy, dummy, dummy);
|
||||
self.connect_extension(dummy, zero);
|
||||
}
|
||||
}
|
||||
assert!(self.batched_gates.free_arithmetic.is_empty());
|
||||
}
|
||||
|
||||
/// Fill the remaining unused arithmetic operations with zeros, so that all
|
||||
/// `ArithmeticExtensionGenerator`s are run.
|
||||
fn fill_mul_gates(&mut self) {
|
||||
let zero = self.zero_extension();
|
||||
for (c0, (_gate, i)) in self.batched_gates.free_mul.clone() {
|
||||
for _ in i..MulExtensionGate::<D>::num_ops(&self.config) {
|
||||
// If we directly wire in zero, an optimization will skip doing anything and return
|
||||
// zero. So we pass in a virtual target and connect it to zero afterward.
|
||||
let dummy = self.add_virtual_extension_target();
|
||||
self.arithmetic_extension(c0, F::ZERO, dummy, dummy, zero);
|
||||
self.connect_extension(dummy, zero);
|
||||
}
|
||||
}
|
||||
assert!(self.batched_gates.free_mul.is_empty());
|
||||
}
|
||||
|
||||
/// Fill the remaining unused random access operations with zeros, so that all
|
||||
/// `RandomAccessGenerator`s are run.
|
||||
fn fill_random_access_gates(&mut self) {
|
||||
let zero = self.zero();
|
||||
for (bits, (_, i)) in self.batched_gates.free_random_access.clone() {
|
||||
let max_copies =
|
||||
RandomAccessGate::<F, D>::new_from_config(&self.config, bits).num_copies;
|
||||
for _ in i..max_copies {
|
||||
self.random_access(zero, zero, vec![zero; 1 << bits]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Fill the remaining unused switch gates with dummy values, so that all
|
||||
/// `SwitchGenerator`s are run.
|
||||
fn fill_switch_gates(&mut self) {
|
||||
let zero = self.zero();
|
||||
|
||||
for chunk_size in 1..=self.batched_gates.current_switch_gates.len() {
|
||||
if let Some((gate, gate_index, mut copy)) =
|
||||
self.batched_gates.current_switch_gates[chunk_size - 1].clone()
|
||||
{
|
||||
while copy < gate.num_copies {
|
||||
for element in 0..chunk_size {
|
||||
let wire_first_input =
|
||||
Target::wire(gate_index, gate.wire_first_input(copy, element));
|
||||
let wire_second_input =
|
||||
Target::wire(gate_index, gate.wire_second_input(copy, element));
|
||||
let wire_switch_bool =
|
||||
Target::wire(gate_index, gate.wire_switch_bool(copy));
|
||||
self.connect(zero, wire_first_input);
|
||||
self.connect(zero, wire_second_input);
|
||||
self.connect(zero, wire_switch_bool);
|
||||
}
|
||||
copy += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Fill the remaining unused U32 arithmetic operations with zeros, so that all
|
||||
/// `U32ArithmeticGenerator`s are run.
|
||||
fn fill_u32_arithmetic_gates(&mut self) {
|
||||
let zero = self.zero_u32();
|
||||
if let Some((_gate_index, copy)) = self.batched_gates.current_u32_arithmetic_gate {
|
||||
for _ in copy..U32ArithmeticGate::<F, D>::num_ops(&self.config) {
|
||||
let dummy = self.add_virtual_u32_target();
|
||||
self.mul_add_u32(dummy, dummy, dummy);
|
||||
self.connect_u32(dummy, zero);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Fill the remaining unused U32 subtraction operations with zeros, so that all
|
||||
/// `U32SubtractionGenerator`s are run.
|
||||
fn fill_u32_subtraction_gates(&mut self) {
|
||||
let zero = self.zero_u32();
|
||||
if let Some((_gate_index, copy)) = self.batched_gates.current_u32_subtraction_gate {
|
||||
for _i in copy..U32SubtractionGate::<F, D>::num_ops(&self.config) {
|
||||
let dummy = self.add_virtual_u32_target();
|
||||
self.sub_u32(dummy, dummy, dummy);
|
||||
self.connect_u32(dummy, zero);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn fill_batched_gates(&mut self) {
|
||||
self.fill_arithmetic_gates();
|
||||
self.fill_base_arithmetic_gates();
|
||||
self.fill_mul_gates();
|
||||
self.fill_random_access_gates();
|
||||
self.fill_switch_gates();
|
||||
self.fill_u32_arithmetic_gates();
|
||||
self.fill_u32_subtraction_gates();
|
||||
}
|
||||
}
|
||||
|
||||
@ -16,6 +16,7 @@ use crate::iop::target::Target;
|
||||
use crate::iop::witness::PartialWitness;
|
||||
use crate::plonk::config::{GenericConfig, Hasher};
|
||||
use crate::plonk::proof::ProofWithPublicInputs;
|
||||
use crate::plonk::proof::{CompressedProofWithPublicInputs, ProofWithPublicInputs};
|
||||
use crate::plonk::prover::prove;
|
||||
use crate::plonk::verifier::verify;
|
||||
use crate::util::marking::MarkedTargets;
|
||||
@ -26,6 +27,9 @@ pub struct CircuitConfig {
|
||||
pub num_wires: usize,
|
||||
pub num_routed_wires: usize,
|
||||
pub constant_gate_size: usize,
|
||||
/// Whether to use a dedicated gate for base field arithmetic, rather than using a single gate
|
||||
/// for both base field and extension field arithmetic.
|
||||
pub use_base_arithmetic_gate: bool,
|
||||
pub security_bits: usize,
|
||||
pub rate_bits: usize,
|
||||
/// The number of challenge points to generate, for IOPs that have soundness errors of (roughly)
|
||||
@ -45,30 +49,35 @@ impl Default for CircuitConfig {
|
||||
}
|
||||
|
||||
impl CircuitConfig {
|
||||
pub fn rate(&self) -> f64 {
|
||||
1.0 / ((1 << self.rate_bits) as f64)
|
||||
}
|
||||
|
||||
pub fn num_advice_wires(&self) -> usize {
|
||||
self.num_wires - self.num_routed_wires
|
||||
}
|
||||
|
||||
/// A typical recursion config, without zero-knowledge, targeting ~100 bit security.
|
||||
pub(crate) fn standard_recursion_config() -> Self {
|
||||
pub fn standard_recursion_config() -> Self {
|
||||
Self {
|
||||
num_wires: 143,
|
||||
num_routed_wires: 25,
|
||||
constant_gate_size: 6,
|
||||
num_wires: 135,
|
||||
num_routed_wires: 80,
|
||||
constant_gate_size: 5,
|
||||
use_base_arithmetic_gate: true,
|
||||
security_bits: 100,
|
||||
rate_bits: 3,
|
||||
num_challenges: 2,
|
||||
zero_knowledge: false,
|
||||
cap_height: 3,
|
||||
cap_height: 4,
|
||||
fri_config: FriConfig {
|
||||
proof_of_work_bits: 16,
|
||||
reduction_strategy: FriReductionStrategy::ConstantArityBits(3, 5),
|
||||
reduction_strategy: FriReductionStrategy::ConstantArityBits(4, 5),
|
||||
num_query_rounds: 28,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn standard_recursion_zk_config() -> Self {
|
||||
pub fn standard_recursion_zk_config() -> Self {
|
||||
CircuitConfig {
|
||||
zero_knowledge: true,
|
||||
..Self::standard_recursion_config()
|
||||
@ -96,6 +105,13 @@ impl<F: Extendable<D>, C: GenericConfig<D, F = F>, const D: usize> CircuitData<F
|
||||
pub fn verify(&self, proof_with_pis: ProofWithPublicInputs<F, C, D>) -> Result<()> {
|
||||
verify(proof_with_pis, &self.verifier_only, &self.common)
|
||||
}
|
||||
|
||||
pub fn verify_compressed(
|
||||
&self,
|
||||
compressed_proof_with_pis: CompressedProofWithPublicInputs<F, D>,
|
||||
) -> Result<()> {
|
||||
compressed_proof_with_pis.verify(&self.verifier_only, &self.common)
|
||||
}
|
||||
}
|
||||
|
||||
/// Circuit data required by the prover. This may be thought of as a proving key, although it
|
||||
@ -132,6 +148,13 @@ impl<F: Extendable<D>, C: GenericConfig<D, F = F>, const D: usize> VerifierCircu
|
||||
pub fn verify(&self, proof_with_pis: ProofWithPublicInputs<F, C, D>) -> Result<()> {
|
||||
verify(proof_with_pis, &self.verifier_only, &self.common)
|
||||
}
|
||||
|
||||
pub fn verify_compressed(
|
||||
&self,
|
||||
compressed_proof_with_pis: CompressedProofWithPublicInputs<F, D>,
|
||||
) -> Result<()> {
|
||||
compressed_proof_with_pis.verify(&self.verifier_only, &self.common)
|
||||
}
|
||||
}
|
||||
|
||||
/// Circuit data required by the prover, but not the verifier.
|
||||
@ -194,8 +217,8 @@ pub struct CommonCircuitData<F: Extendable<D>, C: GenericConfig<D, F = F>, const
|
||||
/// The `{k_i}` valued used in `S_ID_i` in Plonk's permutation argument.
|
||||
pub(crate) k_is: Vec<F>,
|
||||
|
||||
/// The number of partial products needed to compute the `Z` polynomials and the number
|
||||
/// of partial products needed to compute the final product.
|
||||
/// The number of partial products needed to compute the `Z` polynomials and
|
||||
/// the number of original elements consumed in `partial_products()`.
|
||||
pub(crate) num_partial_products: (usize, usize),
|
||||
|
||||
/// A digest of the "circuit" (i.e. the instance, minus public inputs), which can be used to
|
||||
@ -228,11 +251,6 @@ impl<F: Extendable<D>, C: GenericConfig<D, F = F>, const D: usize> CommonCircuit
|
||||
self.quotient_degree_factor * self.degree()
|
||||
}
|
||||
|
||||
pub fn total_constraints(&self) -> usize {
|
||||
// 2 constraints for each Z check.
|
||||
self.config.num_challenges * 2 + self.num_gate_constraints
|
||||
}
|
||||
|
||||
/// Range of the constants polynomials in the `constants_sigmas_commitment`.
|
||||
pub fn constants_range(&self) -> Range<usize> {
|
||||
0..self.num_constants
|
||||
|
||||
@ -11,7 +11,7 @@ use crate::plonk::proof::{
|
||||
CompressedProof, CompressedProofWithPublicInputs, FriInferredElements, OpeningSet, Proof,
|
||||
ProofChallenges, ProofWithPublicInputs,
|
||||
};
|
||||
use crate::polynomial::polynomial::PolynomialCoeffs;
|
||||
use crate::polynomial::PolynomialCoeffs;
|
||||
use crate::util::reverse_bits;
|
||||
|
||||
fn get_challenges<F: Extendable<D>, C: GenericConfig<D, F = F>, const D: usize>(
|
||||
|
||||
@ -5,7 +5,7 @@ use rayon::prelude::*;
|
||||
use crate::field::field_types::Field;
|
||||
use crate::iop::target::Target;
|
||||
use crate::iop::wire::Wire;
|
||||
use crate::polynomial::polynomial::PolynomialValues;
|
||||
use crate::polynomial::PolynomialValues;
|
||||
|
||||
/// Disjoint Set Forest data-structure following https://en.wikipedia.org/wiki/Disjoint-set_data_structure.
|
||||
pub struct Forest {
|
||||
@ -45,15 +45,23 @@ impl Forest {
|
||||
}
|
||||
|
||||
/// Path compression method, see https://en.wikipedia.org/wiki/Disjoint-set_data_structure#Finding_set_representatives.
|
||||
pub fn find(&mut self, x_index: usize) -> usize {
|
||||
let x_parent = self.parents[x_index];
|
||||
if x_parent != x_index {
|
||||
let root_index = self.find(x_parent);
|
||||
self.parents[x_index] = root_index;
|
||||
root_index
|
||||
} else {
|
||||
x_index
|
||||
pub fn find(&mut self, mut x_index: usize) -> usize {
|
||||
// Note: We avoid recursion here since the chains can be long, causing stack overflows.
|
||||
|
||||
// First, find the representative of the set containing `x_index`.
|
||||
let mut representative = x_index;
|
||||
while self.parents[representative] != representative {
|
||||
representative = self.parents[representative];
|
||||
}
|
||||
|
||||
// Then, update each node in this chain to point directly to the representative.
|
||||
while self.parents[x_index] != x_index {
|
||||
let old_parent = self.parents[x_index];
|
||||
self.parents[x_index] = representative;
|
||||
x_index = old_parent;
|
||||
}
|
||||
|
||||
representative
|
||||
}
|
||||
|
||||
/// Merge two sets.
|
||||
|
||||
@ -42,17 +42,6 @@ impl PlonkPolynomials {
|
||||
index: 3,
|
||||
blinding: true,
|
||||
};
|
||||
|
||||
#[cfg(test)]
|
||||
pub fn polynomials(i: usize) -> PolynomialsIndexBlinding {
|
||||
match i {
|
||||
0 => Self::CONSTANTS_SIGMAS,
|
||||
1 => Self::WIRES,
|
||||
2 => Self::ZS_PARTIAL_PRODUCTS,
|
||||
3 => Self::QUOTIENT,
|
||||
_ => panic!("There are only 4 sets of polynomials in Plonk."),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Evaluate the polynomial which vanishes on any multiplicative subgroup of a given order `n`.
|
||||
|
||||
@ -164,12 +164,12 @@ impl<F: Extendable<D>, C: GenericConfig<D, F = F>, const D: usize>
|
||||
) -> anyhow::Result<ProofWithPublicInputs<F, C, D>> {
|
||||
let challenges = self.get_challenges(common_data)?;
|
||||
let fri_inferred_elements = self.get_inferred_elements(&challenges, common_data);
|
||||
let compressed_proof =
|
||||
let decompressed_proof =
|
||||
self.proof
|
||||
.decompress(&challenges, fri_inferred_elements, common_data);
|
||||
Ok(ProofWithPublicInputs {
|
||||
public_inputs: self.public_inputs,
|
||||
proof: compressed_proof,
|
||||
proof: decompressed_proof,
|
||||
})
|
||||
}
|
||||
|
||||
@ -180,13 +180,13 @@ impl<F: Extendable<D>, C: GenericConfig<D, F = F>, const D: usize>
|
||||
) -> anyhow::Result<()> {
|
||||
let challenges = self.get_challenges(common_data)?;
|
||||
let fri_inferred_elements = self.get_inferred_elements(&challenges, common_data);
|
||||
let compressed_proof =
|
||||
let decompressed_proof =
|
||||
self.proof
|
||||
.decompress(&challenges, fri_inferred_elements, common_data);
|
||||
verify_with_challenges(
|
||||
ProofWithPublicInputs {
|
||||
public_inputs: self.public_inputs,
|
||||
proof: compressed_proof,
|
||||
proof: decompressed_proof,
|
||||
},
|
||||
challenges,
|
||||
verifier_data,
|
||||
@ -312,6 +312,7 @@ mod tests {
|
||||
|
||||
use crate::field::field_types::Field;
|
||||
use crate::fri::reduction_strategies::FriReductionStrategy;
|
||||
use crate::gates::noop::NoopGate;
|
||||
use crate::iop::witness::PartialWitness;
|
||||
use crate::plonk::circuit_builder::CircuitBuilder;
|
||||
use crate::plonk::circuit_data::CircuitConfig;
|
||||
@ -340,6 +341,9 @@ mod tests {
|
||||
let zt = builder.constant(z);
|
||||
let comp_zt = builder.mul(xt, yt);
|
||||
builder.connect(zt, comp_zt);
|
||||
for _ in 0..100 {
|
||||
builder.add_gate(NoopGate, vec![]);
|
||||
}
|
||||
let data = builder.build::<C>();
|
||||
let proof = data.prove(pw)?;
|
||||
verify(proof.clone(), &data.verifier_only, &data.common)?;
|
||||
@ -350,6 +354,6 @@ mod tests {
|
||||
assert_eq!(proof, decompressed_compressed_proof);
|
||||
|
||||
verify(proof, &data.verifier_only, &data.common)?;
|
||||
compressed_proof.verify(&data.verifier_only, &data.common)
|
||||
data.verify_compressed(compressed_proof)
|
||||
}
|
||||
}
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
use std::mem::swap;
|
||||
|
||||
use anyhow::Result;
|
||||
use rayon::prelude::*;
|
||||
|
||||
@ -13,9 +15,9 @@ use crate::plonk::plonk_common::ZeroPolyOnCoset;
|
||||
use crate::plonk::proof::{Proof, ProofWithPublicInputs};
|
||||
use crate::plonk::vanishing_poly::eval_vanishing_poly_base_batch;
|
||||
use crate::plonk::vars::EvaluationVarsBase;
|
||||
use crate::polynomial::polynomial::{PolynomialCoeffs, PolynomialValues};
|
||||
use crate::polynomial::{PolynomialCoeffs, PolynomialValues};
|
||||
use crate::timed;
|
||||
use crate::util::partial_products::partial_products;
|
||||
use crate::util::partial_products::{partial_products_and_z_gx, quotient_chunk_products};
|
||||
use crate::util::timing::TimingTree;
|
||||
use crate::util::{log2_ceil, transpose};
|
||||
|
||||
@ -89,28 +91,22 @@ pub(crate) fn prove<F: Extendable<D>, C: GenericConfig<D, F = F>, const D: usize
|
||||
common_data.quotient_degree_factor < common_data.config.num_routed_wires,
|
||||
"When the number of routed wires is smaller that the degree, we should change the logic to avoid computing partial products."
|
||||
);
|
||||
let mut partial_products = timed!(
|
||||
let mut partial_products_and_zs = timed!(
|
||||
timing,
|
||||
"compute partial products",
|
||||
all_wires_permutation_partial_products(&witness, &betas, &gammas, prover_data, common_data)
|
||||
);
|
||||
|
||||
let plonk_z_vecs = timed!(
|
||||
timing,
|
||||
"compute Z's",
|
||||
compute_zs(&partial_products, common_data)
|
||||
);
|
||||
// Z is expected at the front of our batch; see `zs_range` and `partial_products_range`.
|
||||
let plonk_z_vecs = partial_products_and_zs
|
||||
.iter_mut()
|
||||
.map(|partial_products_and_z| partial_products_and_z.pop().unwrap())
|
||||
.collect();
|
||||
let zs_partial_products = [plonk_z_vecs, partial_products_and_zs.concat()].concat();
|
||||
|
||||
// The first polynomial in `partial_products` represent the final product used in the
|
||||
// computation of `Z`. It isn't needed anymore so we discard it.
|
||||
partial_products.iter_mut().for_each(|part| {
|
||||
part.remove(0);
|
||||
});
|
||||
|
||||
let zs_partial_products = [plonk_z_vecs, partial_products.concat()].concat();
|
||||
let zs_partial_products_commitment = timed!(
|
||||
let partial_products_and_zs_commitment = timed!(
|
||||
timing,
|
||||
"commit to Z's",
|
||||
"commit to partial products and Z's",
|
||||
PolynomialBatchCommitment::from_values(
|
||||
zs_partial_products,
|
||||
config.rate_bits,
|
||||
@ -121,7 +117,7 @@ pub(crate) fn prove<F: Extendable<D>, C: GenericConfig<D, F = F>, const D: usize
|
||||
)
|
||||
);
|
||||
|
||||
challenger.observe_cap(&zs_partial_products_commitment.merkle_tree.cap);
|
||||
challenger.observe_cap(&partial_products_and_zs_commitment.merkle_tree.cap);
|
||||
|
||||
let alphas = challenger.get_n_challenges(num_challenges);
|
||||
|
||||
@ -133,7 +129,7 @@ pub(crate) fn prove<F: Extendable<D>, C: GenericConfig<D, F = F>, const D: usize
|
||||
prover_data,
|
||||
&public_inputs_hash,
|
||||
&wires_commitment,
|
||||
&zs_partial_products_commitment,
|
||||
&partial_products_and_zs_commitment,
|
||||
&betas,
|
||||
&gammas,
|
||||
&alphas,
|
||||
@ -148,7 +144,6 @@ pub(crate) fn prove<F: Extendable<D>, C: GenericConfig<D, F = F>, const D: usize
|
||||
.into_par_iter()
|
||||
.flat_map(|mut quotient_poly| {
|
||||
quotient_poly.trim();
|
||||
// TODO: Return Result instead of panicking.
|
||||
quotient_poly.pad(quotient_degree).expect(
|
||||
"Quotient has failed, the vanishing polynomial is not divisible by `Z_H",
|
||||
);
|
||||
@ -182,7 +177,7 @@ pub(crate) fn prove<F: Extendable<D>, C: GenericConfig<D, F = F>, const D: usize
|
||||
&[
|
||||
&prover_data.constants_sigmas_commitment,
|
||||
&wires_commitment,
|
||||
&zs_partial_products_commitment,
|
||||
&partial_products_and_zs_commitment,
|
||||
"ient_polys_commitment,
|
||||
],
|
||||
zeta,
|
||||
@ -194,7 +189,7 @@ pub(crate) fn prove<F: Extendable<D>, C: GenericConfig<D, F = F>, const D: usize
|
||||
|
||||
let proof = Proof {
|
||||
wires_cap: wires_commitment.merkle_tree.cap,
|
||||
plonk_zs_partial_products_cap: zs_partial_products_commitment.merkle_tree.cap,
|
||||
plonk_zs_partial_products_cap: partial_products_and_zs_commitment.merkle_tree.cap,
|
||||
quotient_polys_cap: quotient_polys_commitment.merkle_tree.cap,
|
||||
openings,
|
||||
opening_proof,
|
||||
@ -219,7 +214,7 @@ fn all_wires_permutation_partial_products<
|
||||
) -> Vec<Vec<PolynomialValues<F>>> {
|
||||
(0..common_data.config.num_challenges)
|
||||
.map(|i| {
|
||||
wires_permutation_partial_products(
|
||||
wires_permutation_partial_products_and_zs(
|
||||
witness,
|
||||
betas[i],
|
||||
gammas[i],
|
||||
@ -233,7 +228,7 @@ fn all_wires_permutation_partial_products<
|
||||
/// Compute the partial products used in the `Z` polynomial.
|
||||
/// Returns the polynomials interpolating `partial_products(f / g)`
|
||||
/// where `f, g` are the products in the definition of `Z`: `Z(g^i) = f / g`.
|
||||
fn wires_permutation_partial_products<
|
||||
fn wires_permutation_partial_products_and_zs<
|
||||
F: Extendable<D>,
|
||||
C: GenericConfig<D, F = F>,
|
||||
const D: usize,
|
||||
@ -247,7 +242,8 @@ fn wires_permutation_partial_products<
|
||||
let degree = common_data.quotient_degree_factor;
|
||||
let subgroup = &prover_data.subgroup;
|
||||
let k_is = &common_data.k_is;
|
||||
let values = subgroup
|
||||
let (num_prods, _final_num_prod) = common_data.num_partial_products;
|
||||
let all_quotient_chunk_products = subgroup
|
||||
.par_iter()
|
||||
.enumerate()
|
||||
.map(|(i, &x)| {
|
||||
@ -271,49 +267,26 @@ fn wires_permutation_partial_products<
|
||||
.map(|(num, den_inv)| num * den_inv)
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let quotient_partials = partial_products("ient_values, degree);
|
||||
|
||||
// This is the final product for the quotient.
|
||||
let quotient = quotient_partials
|
||||
[common_data.num_partial_products.0 - common_data.num_partial_products.1..]
|
||||
.iter()
|
||||
.copied()
|
||||
.product();
|
||||
|
||||
// We add the quotient at the beginning of the vector to reuse them later in the computation of `Z`.
|
||||
[vec![quotient], quotient_partials].concat()
|
||||
quotient_chunk_products("ient_values, degree)
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
transpose(&values)
|
||||
let mut z_x = F::ONE;
|
||||
let mut all_partial_products_and_zs = Vec::new();
|
||||
for quotient_chunk_products in all_quotient_chunk_products {
|
||||
let mut partial_products_and_z_gx =
|
||||
partial_products_and_z_gx(z_x, "ient_chunk_products);
|
||||
// The last term is Z(gx), but we replace it with Z(x), otherwise Z would end up shifted.
|
||||
swap(&mut z_x, &mut partial_products_and_z_gx[num_prods]);
|
||||
all_partial_products_and_zs.push(partial_products_and_z_gx);
|
||||
}
|
||||
|
||||
transpose(&all_partial_products_and_zs)
|
||||
.into_par_iter()
|
||||
.map(PolynomialValues::new)
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn compute_zs<F: Extendable<D>, C: GenericConfig<D, F = F>, const D: usize>(
|
||||
partial_products: &[Vec<PolynomialValues<F>>],
|
||||
common_data: &CommonCircuitData<F, C, D>,
|
||||
) -> Vec<PolynomialValues<F>> {
|
||||
(0..common_data.config.num_challenges)
|
||||
.map(|i| compute_z(&partial_products[i], common_data))
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Compute the `Z` polynomial by reusing the computations done in `wires_permutation_partial_products`.
|
||||
fn compute_z<F: Extendable<D>, C: GenericConfig<D, F = F>, const D: usize>(
|
||||
partial_products: &[PolynomialValues<F>],
|
||||
common_data: &CommonCircuitData<F, C, D>,
|
||||
) -> PolynomialValues<F> {
|
||||
let mut plonk_z_points = vec![F::ONE];
|
||||
for i in 1..common_data.degree() {
|
||||
let quotient = partial_products[0].values[i - 1];
|
||||
let last = *plonk_z_points.last().unwrap();
|
||||
plonk_z_points.push(last * quotient);
|
||||
}
|
||||
plonk_z_points.into()
|
||||
}
|
||||
|
||||
const BATCH_SIZE: usize = 32;
|
||||
|
||||
fn compute_quotient_polys<'a, F: Extendable<D>, C: GenericConfig<D, F = F>, const D: usize>(
|
||||
|
||||
@ -133,6 +133,7 @@ mod tests {
|
||||
use crate::fri::reduction_strategies::FriReductionStrategy;
|
||||
use crate::fri::FriConfig;
|
||||
use crate::gadgets::polynomial::PolynomialCoeffsExtTarget;
|
||||
use crate::gates::noop::NoopGate;
|
||||
use crate::hash::merkle_proofs::MerkleProofTarget;
|
||||
use crate::iop::witness::{PartialWitness, Witness};
|
||||
use crate::plonk::circuit_data::VerifierOnlyCircuitData;
|
||||
@ -369,9 +370,8 @@ mod tests {
|
||||
type F = <C as GenericConfig<D>>::F;
|
||||
let config = CircuitConfig::standard_recursion_config();
|
||||
|
||||
let (proof, vd, cd) = dummy_proof::<F, C, D>(&config, 8_000)?;
|
||||
let (proof, _vd, cd) =
|
||||
recursive_proof::<F, C, C, D>(proof, vd, cd, &config, &config, true, true)?;
|
||||
let (proof, vd, cd) = dummy_proof::<F,C, D>(&config, 4_000)?;
|
||||
let (proof, _vd, cd) = recursive_proof::<F,C,C,D>(proof, vd, cd, &config, &config, None, true, true)?;
|
||||
test_serialization(&proof, &cd)?;
|
||||
|
||||
Ok(())
|
||||
@ -388,11 +388,14 @@ mod tests {
|
||||
|
||||
let config = CircuitConfig::standard_recursion_config();
|
||||
|
||||
let (proof, vd, cd) = dummy_proof::<F, C, D>(&config, 8_000)?;
|
||||
// Start with a degree 2^14 proof, then shrink it to 2^13, then to 2^12.
|
||||
let (proof, vd, cd) = dummy_proof::<F,C, D>(&config, 16_000)?;
|
||||
assert_eq!(cd.degree_bits, 14);
|
||||
let (proof, vd, cd) =
|
||||
recursive_proof::<F, C, C, D>(proof, vd, cd, &config, &config, false, false)?;
|
||||
let (proof, _vd, cd) =
|
||||
recursive_proof::<F, KC, C, D>(proof, vd, cd, &config, &config, true, true)?;
|
||||
recursive_proof::<F,C,C,D>(proof, vd, cd, &config, &config, Some(13), false, false)?;
|
||||
assert_eq!(cd.degree_bits, 13);
|
||||
let (proof, _vd, cd) = recursive_proof::<F,KC,C,D>(proof, vd, cd, &config, &config, None, true, true)?;
|
||||
assert_eq!(cd.degree_bits, 12);
|
||||
|
||||
test_serialization(&proof, &cd)?;
|
||||
|
||||
@ -412,29 +415,29 @@ mod tests {
|
||||
|
||||
let standard_config = CircuitConfig::standard_recursion_config();
|
||||
|
||||
// A dummy proof with degree 2^13.
|
||||
let (proof, vd, cd) = dummy_proof::<F, C, D>(&standard_config, 8_000)?;
|
||||
assert_eq!(cd.degree_bits, 13);
|
||||
// An initial dummy proof.
|
||||
let (proof, vd, cd) = dummy_proof::<F,C, D>(&standard_config, 4_000)?;
|
||||
assert_eq!(cd.degree_bits, 12);
|
||||
|
||||
// A standard recursive proof with degree 2^13.
|
||||
// A standard recursive proof.
|
||||
let (proof, vd, cd) = recursive_proof(
|
||||
proof,
|
||||
vd,
|
||||
cd,
|
||||
&standard_config,
|
||||
&standard_config,
|
||||
None,
|
||||
false,
|
||||
false,
|
||||
)?;
|
||||
assert_eq!(cd.degree_bits, 13);
|
||||
assert_eq!(cd.degree_bits, 12);
|
||||
|
||||
// A high-rate recursive proof with degree 2^13, designed to be verifiable with 2^12
|
||||
// gates and 48 routed wires.
|
||||
// A high-rate recursive proof, designed to be verifiable with fewer routed wires.
|
||||
let high_rate_config = CircuitConfig {
|
||||
rate_bits: 5,
|
||||
rate_bits: 7,
|
||||
fri_config: FriConfig {
|
||||
proof_of_work_bits: 20,
|
||||
num_query_rounds: 16,
|
||||
proof_of_work_bits: 16,
|
||||
num_query_rounds: 12,
|
||||
..standard_config.fri_config.clone()
|
||||
},
|
||||
..standard_config
|
||||
@ -445,54 +448,35 @@ mod tests {
|
||||
cd,
|
||||
&standard_config,
|
||||
&high_rate_config,
|
||||
true,
|
||||
true,
|
||||
)?;
|
||||
assert_eq!(cd.degree_bits, 13);
|
||||
|
||||
// A higher-rate recursive proof with degree 2^12, designed to be verifiable with 2^12
|
||||
// gates and 28 routed wires.
|
||||
let higher_rate_more_routing_config = CircuitConfig {
|
||||
rate_bits: 7,
|
||||
num_routed_wires: 48,
|
||||
fri_config: FriConfig {
|
||||
proof_of_work_bits: 23,
|
||||
num_query_rounds: 11,
|
||||
..standard_config.fri_config.clone()
|
||||
},
|
||||
..high_rate_config.clone()
|
||||
};
|
||||
let (proof, vd, cd) = recursive_proof::<F, C, C, D>(
|
||||
proof,
|
||||
vd,
|
||||
cd,
|
||||
&high_rate_config,
|
||||
&higher_rate_more_routing_config,
|
||||
None,
|
||||
true,
|
||||
true,
|
||||
)?;
|
||||
assert_eq!(cd.degree_bits, 12);
|
||||
|
||||
// A final proof of degree 2^12, optimized for size.
|
||||
// A final proof, optimized for size.
|
||||
let final_config = CircuitConfig {
|
||||
cap_height: 0,
|
||||
num_routed_wires: 32,
|
||||
rate_bits: 8,
|
||||
num_routed_wires: 37,
|
||||
fri_config: FriConfig {
|
||||
proof_of_work_bits: 20,
|
||||
reduction_strategy: FriReductionStrategy::MinSize(None),
|
||||
..higher_rate_more_routing_config.fri_config.clone()
|
||||
num_query_rounds: 10,
|
||||
},
|
||||
..higher_rate_more_routing_config
|
||||
..high_rate_config
|
||||
};
|
||||
let (proof, _vd, cd) = recursive_proof::<F, KC, C, D>(
|
||||
proof,
|
||||
vd,
|
||||
cd,
|
||||
&higher_rate_more_routing_config,
|
||||
&high_rate_config,
|
||||
&final_config,
|
||||
None,
|
||||
true,
|
||||
true,
|
||||
)?;
|
||||
assert_eq!(cd.degree_bits, 12);
|
||||
assert_eq!(cd.degree_bits, 12, "final proof too large");
|
||||
|
||||
test_serialization(&proof, &cd)?;
|
||||
|
||||
@ -509,16 +493,12 @@ mod tests {
|
||||
CommonCircuitData<F, C, D>,
|
||||
)> {
|
||||
let mut builder = CircuitBuilder::<F, D>::new(config.clone());
|
||||
let input = builder.add_virtual_target();
|
||||
for i in 0..num_dummy_gates {
|
||||
// Use unique constants to force a new `ArithmeticGate`.
|
||||
let i_f = F::from_canonical_u64(i);
|
||||
builder.arithmetic(i_f, i_f, input, input, input);
|
||||
for _ in 0..num_dummy_gates {
|
||||
builder.add_gate(NoopGate, vec![]);
|
||||
}
|
||||
|
||||
let data = builder.build::<C>();
|
||||
let mut inputs = PartialWitness::new();
|
||||
inputs.set_target(input, F::ZERO);
|
||||
let inputs = PartialWitness::new();
|
||||
let proof = data.prove(inputs)?;
|
||||
data.verify(proof.clone())?;
|
||||
|
||||
@ -536,6 +516,7 @@ mod tests {
|
||||
inner_cd: CommonCircuitData<F, InnerC, D>,
|
||||
inner_config: &CircuitConfig,
|
||||
config: &CircuitConfig,
|
||||
min_degree_bits: Option<usize>,
|
||||
print_gate_counts: bool,
|
||||
print_timing: bool,
|
||||
) -> Result<(
|
||||
@ -556,12 +537,22 @@ mod tests {
|
||||
&inner_vd.constants_sigmas_cap,
|
||||
);
|
||||
|
||||
builder.add_recursive_verifier(pt, &inner_config, &inner_data, &inner_cd);
|
||||
builder.add_recursive_verifier(pt, inner_config, &inner_data, &inner_cd);
|
||||
|
||||
if print_gate_counts {
|
||||
builder.print_gate_counts(0);
|
||||
}
|
||||
|
||||
if let Some(min_degree_bits) = min_degree_bits {
|
||||
// We don't want to pad all the way up to 2^min_degree_bits, as the builder will add a
|
||||
// few special gates afterward. So just pad to 2^(min_degree_bits - 1) + 1. Then the
|
||||
// builder will pad to the next power of two, 2^min_degree_bits.
|
||||
let min_gates = (1 << (min_degree_bits - 1)) + 1;
|
||||
for _ in builder.num_gates()..min_gates {
|
||||
builder.add_gate(NoopGate, vec![]);
|
||||
}
|
||||
}
|
||||
|
||||
let data = builder.build::<C>();
|
||||
|
||||
let mut timing = TimingTree::new("prove", Level::Debug);
|
||||
@ -582,12 +573,12 @@ mod tests {
|
||||
) -> Result<()> {
|
||||
let proof_bytes = proof.to_bytes()?;
|
||||
info!("Proof length: {} bytes", proof_bytes.len());
|
||||
let proof_from_bytes = ProofWithPublicInputs::from_bytes(proof_bytes, &cd)?;
|
||||
let proof_from_bytes = ProofWithPublicInputs::from_bytes(proof_bytes, cd)?;
|
||||
assert_eq!(proof, &proof_from_bytes);
|
||||
|
||||
let now = std::time::Instant::now();
|
||||
let compressed_proof = proof.clone().compress(&cd)?;
|
||||
let decompressed_compressed_proof = compressed_proof.clone().decompress(&cd)?;
|
||||
let compressed_proof = proof.clone().compress(cd)?;
|
||||
let decompressed_compressed_proof = compressed_proof.clone().decompress(cd)?;
|
||||
info!("{:.4}s to compress proof", now.elapsed().as_secs_f64());
|
||||
assert_eq!(proof, &decompressed_compressed_proof);
|
||||
|
||||
@ -597,7 +588,7 @@ mod tests {
|
||||
compressed_proof_bytes.len()
|
||||
);
|
||||
let compressed_proof_from_bytes =
|
||||
CompressedProofWithPublicInputs::from_bytes(compressed_proof_bytes, &cd)?;
|
||||
CompressedProofWithPublicInputs::from_bytes(compressed_proof_bytes, cd)?;
|
||||
assert_eq!(compressed_proof, compressed_proof_from_bytes);
|
||||
|
||||
Ok(())
|
||||
|
||||
@ -29,7 +29,7 @@ pub(crate) fn eval_vanishing_poly<F: Extendable<D>, C: GenericConfig<D, F = F>,
|
||||
alphas: &[F],
|
||||
) -> Vec<F::Extension> {
|
||||
let max_degree = common_data.quotient_degree_factor;
|
||||
let (num_prods, final_num_prod) = common_data.num_partial_products;
|
||||
let (num_prods, _final_num_prod) = common_data.num_partial_products;
|
||||
|
||||
let constraint_terms =
|
||||
evaluate_gate_constraints(&common_data.gates, common_data.num_gate_constraints, vars);
|
||||
@ -38,14 +38,12 @@ pub(crate) fn eval_vanishing_poly<F: Extendable<D>, C: GenericConfig<D, F = F>,
|
||||
let mut vanishing_z_1_terms = Vec::new();
|
||||
// The terms checking the partial products.
|
||||
let mut vanishing_partial_products_terms = Vec::new();
|
||||
// The Z(x) f'(x) - g'(x) Z(g x) terms.
|
||||
let mut vanishing_v_shift_terms = Vec::new();
|
||||
|
||||
let l1_x = plonk_common::eval_l_1(common_data.degree(), x);
|
||||
|
||||
for i in 0..common_data.config.num_challenges {
|
||||
let z_x = local_zs[i];
|
||||
let z_gz = next_zs[i];
|
||||
let z_gx = next_zs[i];
|
||||
vanishing_z_1_terms.push(l1_x * (z_x - F::Extension::ONE));
|
||||
|
||||
let numerator_values = (0..common_data.config.num_routed_wires)
|
||||
@ -63,37 +61,24 @@ pub(crate) fn eval_vanishing_poly<F: Extendable<D>, C: GenericConfig<D, F = F>,
|
||||
wire_value + s_sigma.scalar_mul(betas[i]) + gammas[i].into()
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
let quotient_values = (0..common_data.config.num_routed_wires)
|
||||
.map(|j| numerator_values[j] / denominator_values[j])
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
// The partial products considered for this iteration of `i`.
|
||||
let current_partial_products = &partial_products[i * num_prods..(i + 1) * num_prods];
|
||||
// Check the quotient partial products.
|
||||
let mut partial_product_check =
|
||||
check_partial_products("ient_values, current_partial_products, max_degree);
|
||||
// The first checks are of the form `q - n/d` which is a rational function not a polynomial.
|
||||
// We multiply them by `d` to get checks of the form `q*d - n` which low-degree polynomials.
|
||||
denominator_values
|
||||
.chunks(max_degree)
|
||||
.zip(partial_product_check.iter_mut())
|
||||
.for_each(|(d, q)| {
|
||||
*q *= d.iter().copied().product();
|
||||
});
|
||||
vanishing_partial_products_terms.extend(partial_product_check);
|
||||
|
||||
// The quotient final product is the product of the last `final_num_prod` elements.
|
||||
let quotient: F::Extension = current_partial_products[num_prods - final_num_prod..]
|
||||
.iter()
|
||||
.copied()
|
||||
.product();
|
||||
vanishing_v_shift_terms.push(quotient * z_x - z_gz);
|
||||
let partial_product_checks = check_partial_products(
|
||||
&numerator_values,
|
||||
&denominator_values,
|
||||
current_partial_products,
|
||||
z_x,
|
||||
z_gx,
|
||||
max_degree,
|
||||
);
|
||||
vanishing_partial_products_terms.extend(partial_product_checks);
|
||||
}
|
||||
|
||||
let vanishing_terms = [
|
||||
vanishing_z_1_terms,
|
||||
vanishing_partial_products_terms,
|
||||
vanishing_v_shift_terms,
|
||||
constraint_terms,
|
||||
]
|
||||
.concat();
|
||||
@ -130,7 +115,7 @@ pub(crate) fn eval_vanishing_poly_base_batch<
|
||||
assert_eq!(s_sigmas_batch.len(), n);
|
||||
|
||||
let max_degree = common_data.quotient_degree_factor;
|
||||
let (num_prods, final_num_prod) = common_data.num_partial_products;
|
||||
let (num_prods, _final_num_prod) = common_data.num_partial_products;
|
||||
|
||||
let num_gate_constraints = common_data.num_gate_constraints;
|
||||
|
||||
@ -143,14 +128,11 @@ pub(crate) fn eval_vanishing_poly_base_batch<
|
||||
|
||||
let mut numerator_values = Vec::with_capacity(num_routed_wires);
|
||||
let mut denominator_values = Vec::with_capacity(num_routed_wires);
|
||||
let mut quotient_values = Vec::with_capacity(num_routed_wires);
|
||||
|
||||
// The L_1(x) (Z(x) - 1) vanishing terms.
|
||||
let mut vanishing_z_1_terms = Vec::with_capacity(num_challenges);
|
||||
// The terms checking the partial products.
|
||||
let mut vanishing_partial_products_terms = Vec::new();
|
||||
// The Z(x) f'(x) - g'(x) Z(g x) terms.
|
||||
let mut vanishing_v_shift_terms = Vec::with_capacity(num_challenges);
|
||||
|
||||
let mut res_batch: Vec<Vec<F>> = Vec::with_capacity(n);
|
||||
for k in 0..n {
|
||||
@ -168,7 +150,7 @@ pub(crate) fn eval_vanishing_poly_base_batch<
|
||||
let l1_x = z_h_on_coset.eval_l1(index, x);
|
||||
for i in 0..num_challenges {
|
||||
let z_x = local_zs[i];
|
||||
let z_gz = next_zs[i];
|
||||
let z_gx = next_zs[i];
|
||||
vanishing_z_1_terms.push(l1_x * z_x.sub_one());
|
||||
|
||||
numerator_values.extend((0..num_routed_wires).map(|j| {
|
||||
@ -182,49 +164,33 @@ pub(crate) fn eval_vanishing_poly_base_batch<
|
||||
let s_sigma = s_sigmas[j];
|
||||
wire_value + betas[i] * s_sigma + gammas[i]
|
||||
}));
|
||||
let denominator_inverses = F::batch_multiplicative_inverse(&denominator_values);
|
||||
quotient_values.extend(
|
||||
(0..num_routed_wires).map(|j| numerator_values[j] * denominator_inverses[j]),
|
||||
);
|
||||
|
||||
// The partial products considered for this iteration of `i`.
|
||||
let current_partial_products = &partial_products[i * num_prods..(i + 1) * num_prods];
|
||||
// Check the numerator partial products.
|
||||
let mut partial_product_check =
|
||||
check_partial_products("ient_values, current_partial_products, max_degree);
|
||||
// The first checks are of the form `q - n/d` which is a rational function not a polynomial.
|
||||
// We multiply them by `d` to get checks of the form `q*d - n` which low-degree polynomials.
|
||||
denominator_values
|
||||
.chunks(max_degree)
|
||||
.zip(partial_product_check.iter_mut())
|
||||
.for_each(|(d, q)| {
|
||||
*q *= d.iter().copied().product();
|
||||
});
|
||||
vanishing_partial_products_terms.extend(partial_product_check);
|
||||
|
||||
// The quotient final product is the product of the last `final_num_prod` elements.
|
||||
let quotient: F = current_partial_products[num_prods - final_num_prod..]
|
||||
.iter()
|
||||
.copied()
|
||||
.product();
|
||||
vanishing_v_shift_terms.push(quotient * z_x - z_gz);
|
||||
let partial_product_checks = check_partial_products(
|
||||
&numerator_values,
|
||||
&denominator_values,
|
||||
current_partial_products,
|
||||
z_x,
|
||||
z_gx,
|
||||
max_degree,
|
||||
);
|
||||
vanishing_partial_products_terms.extend(partial_product_checks);
|
||||
|
||||
numerator_values.clear();
|
||||
denominator_values.clear();
|
||||
quotient_values.clear();
|
||||
}
|
||||
|
||||
let vanishing_terms = vanishing_z_1_terms
|
||||
.iter()
|
||||
.chain(vanishing_partial_products_terms.iter())
|
||||
.chain(vanishing_v_shift_terms.iter())
|
||||
.chain(constraint_terms);
|
||||
let res = plonk_common::reduce_with_powers_multi(vanishing_terms, alphas);
|
||||
res_batch.push(res);
|
||||
|
||||
vanishing_z_1_terms.clear();
|
||||
vanishing_partial_products_terms.clear();
|
||||
vanishing_v_shift_terms.clear();
|
||||
}
|
||||
res_batch
|
||||
}
|
||||
@ -334,7 +300,7 @@ pub(crate) fn eval_vanishing_poly_recursively<
|
||||
alphas: &[Target],
|
||||
) -> Vec<ExtensionTarget<D>> {
|
||||
let max_degree = common_data.quotient_degree_factor;
|
||||
let (num_prods, final_num_prod) = common_data.num_partial_products;
|
||||
let (num_prods, _final_num_prod) = common_data.num_partial_products;
|
||||
|
||||
let constraint_terms = with_context!(
|
||||
builder,
|
||||
@ -351,8 +317,6 @@ pub(crate) fn eval_vanishing_poly_recursively<
|
||||
let mut vanishing_z_1_terms = Vec::new();
|
||||
// The terms checking the partial products.
|
||||
let mut vanishing_partial_products_terms = Vec::new();
|
||||
// The Z(x) f'(x) - g'(x) Z(g x) terms.
|
||||
let mut vanishing_v_shift_terms = Vec::new();
|
||||
|
||||
let l1_x = eval_l_1_recursively(builder, common_data.degree(), x, x_pow_deg);
|
||||
|
||||
@ -365,14 +329,13 @@ pub(crate) fn eval_vanishing_poly_recursively<
|
||||
|
||||
for i in 0..common_data.config.num_challenges {
|
||||
let z_x = local_zs[i];
|
||||
let z_gz = next_zs[i];
|
||||
let z_gx = next_zs[i];
|
||||
|
||||
// L_1(x) Z(x) = 0.
|
||||
vanishing_z_1_terms.push(builder.mul_sub_extension(l1_x, z_x, l1_x));
|
||||
|
||||
let mut numerator_values = Vec::new();
|
||||
let mut denominator_values = Vec::new();
|
||||
let mut quotient_values = Vec::new();
|
||||
|
||||
for j in 0..common_data.config.num_routed_wires {
|
||||
let wire_value = vars.local_wires[j];
|
||||
@ -385,44 +348,28 @@ pub(crate) fn eval_vanishing_poly_recursively<
|
||||
let numerator = builder.mul_add_extension(beta_ext, s_ids[j], wire_value_plus_gamma);
|
||||
let denominator =
|
||||
builder.mul_add_extension(beta_ext, s_sigmas[j], wire_value_plus_gamma);
|
||||
let quotient = builder.div_extension(numerator, denominator);
|
||||
|
||||
numerator_values.push(numerator);
|
||||
denominator_values.push(denominator);
|
||||
quotient_values.push(quotient);
|
||||
}
|
||||
|
||||
// The partial products considered for this iteration of `i`.
|
||||
let current_partial_products = &partial_products[i * num_prods..(i + 1) * num_prods];
|
||||
// Check the quotient partial products.
|
||||
let mut partial_product_check = check_partial_products_recursively(
|
||||
let partial_product_checks = check_partial_products_recursively(
|
||||
builder,
|
||||
"ient_values,
|
||||
&numerator_values,
|
||||
&denominator_values,
|
||||
current_partial_products,
|
||||
z_x,
|
||||
z_gx,
|
||||
max_degree,
|
||||
);
|
||||
// The first checks are of the form `q - n/d` which is a rational function not a polynomial.
|
||||
// We multiply them by `d` to get checks of the form `q*d - n` which low-degree polynomials.
|
||||
denominator_values
|
||||
.chunks(max_degree)
|
||||
.zip(partial_product_check.iter_mut())
|
||||
.for_each(|(d, q)| {
|
||||
let mut v = d.to_vec();
|
||||
v.push(*q);
|
||||
*q = builder.mul_many_extension(&v);
|
||||
});
|
||||
vanishing_partial_products_terms.extend(partial_product_check);
|
||||
|
||||
// The quotient final product is the product of the last `final_num_prod` elements.
|
||||
let quotient =
|
||||
builder.mul_many_extension(¤t_partial_products[num_prods - final_num_prod..]);
|
||||
vanishing_v_shift_terms.push(builder.mul_sub_extension(quotient, z_x, z_gz));
|
||||
vanishing_partial_products_terms.extend(partial_product_checks);
|
||||
}
|
||||
|
||||
let vanishing_terms = [
|
||||
vanishing_z_1_terms,
|
||||
vanishing_partial_products_terms,
|
||||
vanishing_v_shift_terms,
|
||||
constraint_terms,
|
||||
]
|
||||
.concat();
|
||||
|
||||
@ -1,4 +1,3 @@
|
||||
use std::convert::TryInto;
|
||||
use std::ops::Range;
|
||||
|
||||
use crate::field::extension_field::algebra::ExtensionAlgebra;
|
||||
|
||||
@ -1,7 +1,6 @@
|
||||
use crate::field::fft::{fft, ifft};
|
||||
use crate::field::field_types::Field;
|
||||
use crate::polynomial::polynomial::PolynomialCoeffs;
|
||||
use crate::util::{log2_ceil, log2_strict};
|
||||
use crate::polynomial::PolynomialCoeffs;
|
||||
use crate::util::log2_ceil;
|
||||
|
||||
impl<F: Field> PolynomialCoeffs<F> {
|
||||
/// Polynomial division.
|
||||
@ -67,63 +66,6 @@ impl<F: Field> PolynomialCoeffs<F> {
|
||||
}
|
||||
}
|
||||
|
||||
/// Takes a polynomial `a` in coefficient form, and divides it by `Z_H = X^n - 1`.
|
||||
///
|
||||
/// This assumes `Z_H | a`, otherwise result is meaningless.
|
||||
pub(crate) fn divide_by_z_h(&self, n: usize) -> PolynomialCoeffs<F> {
|
||||
let mut a = self.clone();
|
||||
|
||||
// TODO: Is this special case needed?
|
||||
if a.coeffs.iter().all(|p| *p == F::ZERO) {
|
||||
return a;
|
||||
}
|
||||
|
||||
let g = F::MULTIPLICATIVE_GROUP_GENERATOR;
|
||||
let mut g_pow = F::ONE;
|
||||
// Multiply the i-th coefficient of `a` by `g^i`. Then `new_a(w^j) = old_a(g.w^j)`.
|
||||
a.coeffs.iter_mut().for_each(|x| {
|
||||
*x *= g_pow;
|
||||
g_pow *= g;
|
||||
});
|
||||
|
||||
let root = F::primitive_root_of_unity(log2_strict(a.len()));
|
||||
// Equals to the evaluation of `a` on `{g.w^i}`.
|
||||
let mut a_eval = fft(&a);
|
||||
// Compute the denominators `1/(g^n.w^(n*i) - 1)` using batch inversion.
|
||||
let denominator_g = g.exp_u64(n as u64);
|
||||
let root_n = root.exp_u64(n as u64);
|
||||
let mut root_pow = F::ONE;
|
||||
let denominators = (0..a_eval.len())
|
||||
.map(|i| {
|
||||
if i != 0 {
|
||||
root_pow *= root_n;
|
||||
}
|
||||
denominator_g * root_pow - F::ONE
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
let denominators_inv = F::batch_multiplicative_inverse(&denominators);
|
||||
// Divide every element of `a_eval` by the corresponding denominator.
|
||||
// Then, `a_eval` is the evaluation of `a/Z_H` on `{g.w^i}`.
|
||||
a_eval
|
||||
.values
|
||||
.iter_mut()
|
||||
.zip(denominators_inv.iter())
|
||||
.for_each(|(x, &d)| {
|
||||
*x *= d;
|
||||
});
|
||||
// `p` is the interpolating polynomial of `a_eval` on `{w^i}`.
|
||||
let mut p = ifft(&a_eval);
|
||||
// We need to scale it by `g^(-i)` to get the interpolating polynomial of `a_eval` on `{g.w^i}`,
|
||||
// a.k.a `a/Z_H`.
|
||||
let g_inv = g.inverse();
|
||||
let mut g_inv_pow = F::ONE;
|
||||
p.coeffs.iter_mut().for_each(|x| {
|
||||
*x *= g_inv_pow;
|
||||
g_inv_pow *= g_inv;
|
||||
});
|
||||
p
|
||||
}
|
||||
|
||||
/// Let `self=p(X)`, this returns `(p(X)-p(z))/(X-z)` and `p(z)`.
|
||||
/// See https://en.wikipedia.org/wiki/Horner%27s_method
|
||||
pub(crate) fn divide_by_linear(&self, z: F) -> (PolynomialCoeffs<F>, F) {
|
||||
@ -187,35 +129,7 @@ mod tests {
|
||||
use crate::field::extension_field::quartic::QuarticExtension;
|
||||
use crate::field::field_types::Field;
|
||||
use crate::field::goldilocks_field::GoldilocksField;
|
||||
use crate::polynomial::polynomial::PolynomialCoeffs;
|
||||
|
||||
#[test]
|
||||
fn zero_div_z_h() {
|
||||
type F = GoldilocksField;
|
||||
let zero = PolynomialCoeffs::<F>::zero(16);
|
||||
let quotient = zero.divide_by_z_h(4);
|
||||
assert_eq!(quotient, zero);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn division_by_z_h() {
|
||||
type F = GoldilocksField;
|
||||
let zero = F::ZERO;
|
||||
let three = F::from_canonical_u64(3);
|
||||
let four = F::from_canonical_u64(4);
|
||||
let five = F::from_canonical_u64(5);
|
||||
let six = F::from_canonical_u64(6);
|
||||
|
||||
// a(x) = Z_4(x) q(x), where
|
||||
// a(x) = 3 x^7 + 4 x^6 + 5 x^5 + 6 x^4 - 3 x^3 - 4 x^2 - 5 x - 6
|
||||
// Z_4(x) = x^4 - 1
|
||||
// q(x) = 3 x^3 + 4 x^2 + 5 x + 6
|
||||
let a = PolynomialCoeffs::new(vec![-six, -five, -four, -three, six, five, four, three]);
|
||||
let q = PolynomialCoeffs::new(vec![six, five, four, three, zero, zero, zero, zero]);
|
||||
|
||||
let computed_q = a.divide_by_z_h(4);
|
||||
assert_eq!(computed_q, q);
|
||||
}
|
||||
use crate::polynomial::PolynomialCoeffs;
|
||||
|
||||
#[test]
|
||||
#[ignore]
|
||||
|
||||
@ -1,2 +1,616 @@
|
||||
pub(crate) mod division;
|
||||
pub mod polynomial;
|
||||
|
||||
use std::cmp::max;
|
||||
use std::iter::Sum;
|
||||
use std::ops::{Add, AddAssign, Mul, MulAssign, Sub, SubAssign};
|
||||
|
||||
use anyhow::{ensure, Result};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::field::extension_field::{Extendable, FieldExtension};
|
||||
use crate::field::fft::{fft, fft_with_options, ifft, FftRootTable};
|
||||
use crate::field::field_types::Field;
|
||||
use crate::util::log2_strict;
|
||||
|
||||
/// A polynomial in point-value form.
|
||||
///
|
||||
/// The points are implicitly `g^i`, where `g` generates the subgroup whose size equals the number
|
||||
/// of points.
|
||||
#[derive(Clone, Debug, Eq, PartialEq)]
|
||||
pub struct PolynomialValues<F: Field> {
|
||||
pub values: Vec<F>,
|
||||
}
|
||||
|
||||
impl<F: Field> PolynomialValues<F> {
|
||||
pub fn new(values: Vec<F>) -> Self {
|
||||
PolynomialValues { values }
|
||||
}
|
||||
|
||||
/// The number of values stored.
|
||||
pub(crate) fn len(&self) -> usize {
|
||||
self.values.len()
|
||||
}
|
||||
|
||||
pub fn ifft(&self) -> PolynomialCoeffs<F> {
|
||||
ifft(self)
|
||||
}
|
||||
|
||||
/// Returns the polynomial whose evaluation on the coset `shift*H` is `self`.
|
||||
pub fn coset_ifft(&self, shift: F) -> PolynomialCoeffs<F> {
|
||||
let mut shifted_coeffs = self.ifft();
|
||||
shifted_coeffs
|
||||
.coeffs
|
||||
.iter_mut()
|
||||
.zip(shift.inverse().powers())
|
||||
.for_each(|(c, r)| {
|
||||
*c *= r;
|
||||
});
|
||||
shifted_coeffs
|
||||
}
|
||||
|
||||
pub fn lde_multiple(polys: Vec<Self>, rate_bits: usize) -> Vec<Self> {
|
||||
polys.into_iter().map(|p| p.lde(rate_bits)).collect()
|
||||
}
|
||||
|
||||
pub fn lde(&self, rate_bits: usize) -> Self {
|
||||
let coeffs = ifft(self).lde(rate_bits);
|
||||
fft_with_options(&coeffs, Some(rate_bits), None)
|
||||
}
|
||||
|
||||
pub fn degree(&self) -> usize {
|
||||
self.degree_plus_one()
|
||||
.checked_sub(1)
|
||||
.expect("deg(0) is undefined")
|
||||
}
|
||||
|
||||
pub fn degree_plus_one(&self) -> usize {
|
||||
self.ifft().degree_plus_one()
|
||||
}
|
||||
}
|
||||
|
||||
impl<F: Field> From<Vec<F>> for PolynomialValues<F> {
|
||||
fn from(values: Vec<F>) -> Self {
|
||||
Self::new(values)
|
||||
}
|
||||
}
|
||||
|
||||
/// A polynomial in coefficient form.
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
#[serde(bound = "")]
|
||||
pub struct PolynomialCoeffs<F: Field> {
|
||||
pub(crate) coeffs: Vec<F>,
|
||||
}
|
||||
|
||||
impl<F: Field> PolynomialCoeffs<F> {
|
||||
pub fn new(coeffs: Vec<F>) -> Self {
|
||||
PolynomialCoeffs { coeffs }
|
||||
}
|
||||
|
||||
pub(crate) fn empty() -> Self {
|
||||
Self::new(Vec::new())
|
||||
}
|
||||
|
||||
pub(crate) fn zero(len: usize) -> Self {
|
||||
Self::new(vec![F::ZERO; len])
|
||||
}
|
||||
|
||||
pub(crate) fn is_zero(&self) -> bool {
|
||||
self.coeffs.iter().all(|x| x.is_zero())
|
||||
}
|
||||
|
||||
/// The number of coefficients. This does not filter out any zero coefficients, so it is not
|
||||
/// necessarily related to the degree.
|
||||
pub fn len(&self) -> usize {
|
||||
self.coeffs.len()
|
||||
}
|
||||
|
||||
pub fn log_len(&self) -> usize {
|
||||
log2_strict(self.len())
|
||||
}
|
||||
|
||||
pub(crate) fn chunks(&self, chunk_size: usize) -> Vec<Self> {
|
||||
self.coeffs
|
||||
.chunks(chunk_size)
|
||||
.map(|chunk| PolynomialCoeffs::new(chunk.to_vec()))
|
||||
.collect()
|
||||
}
|
||||
|
||||
pub fn eval(&self, x: F) -> F {
|
||||
self.coeffs
|
||||
.iter()
|
||||
.rev()
|
||||
.fold(F::ZERO, |acc, &c| acc * x + c)
|
||||
}
|
||||
|
||||
/// Evaluate the polynomial at a point given its powers. The first power is the point itself, not 1.
|
||||
pub fn eval_with_powers(&self, powers: &[F]) -> F {
|
||||
debug_assert_eq!(self.coeffs.len(), powers.len() + 1);
|
||||
let acc = self.coeffs[0];
|
||||
self.coeffs[1..]
|
||||
.iter()
|
||||
.zip(powers)
|
||||
.fold(acc, |acc, (&x, &c)| acc + c * x)
|
||||
}
|
||||
|
||||
pub fn eval_base<const D: usize>(&self, x: F::BaseField) -> F
|
||||
where
|
||||
F: FieldExtension<D>,
|
||||
{
|
||||
self.coeffs
|
||||
.iter()
|
||||
.rev()
|
||||
.fold(F::ZERO, |acc, &c| acc.scalar_mul(x) + c)
|
||||
}
|
||||
|
||||
/// Evaluate the polynomial at a point given its powers. The first power is the point itself, not 1.
|
||||
pub fn eval_base_with_powers<const D: usize>(&self, powers: &[F::BaseField]) -> F
|
||||
where
|
||||
F: FieldExtension<D>,
|
||||
{
|
||||
debug_assert_eq!(self.coeffs.len(), powers.len() + 1);
|
||||
let acc = self.coeffs[0];
|
||||
self.coeffs[1..]
|
||||
.iter()
|
||||
.zip(powers)
|
||||
.fold(acc, |acc, (&x, &c)| acc + x.scalar_mul(c))
|
||||
}
|
||||
|
||||
pub fn lde_multiple(polys: Vec<&Self>, rate_bits: usize) -> Vec<Self> {
|
||||
polys.into_iter().map(|p| p.lde(rate_bits)).collect()
|
||||
}
|
||||
|
||||
pub fn lde(&self, rate_bits: usize) -> Self {
|
||||
self.padded(self.len() << rate_bits)
|
||||
}
|
||||
|
||||
pub(crate) fn pad(&mut self, new_len: usize) -> Result<()> {
|
||||
ensure!(
|
||||
new_len >= self.len(),
|
||||
"Trying to pad a polynomial of length {} to a length of {}.",
|
||||
self.len(),
|
||||
new_len
|
||||
);
|
||||
self.coeffs.resize(new_len, F::ZERO);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) fn padded(&self, new_len: usize) -> Self {
|
||||
let mut poly = self.clone();
|
||||
poly.pad(new_len).unwrap();
|
||||
poly
|
||||
}
|
||||
|
||||
/// Removes leading zero coefficients.
|
||||
pub fn trim(&mut self) {
|
||||
self.coeffs.truncate(self.degree_plus_one());
|
||||
}
|
||||
|
||||
/// Removes leading zero coefficients.
|
||||
pub fn trimmed(&self) -> Self {
|
||||
let coeffs = self.coeffs[..self.degree_plus_one()].to_vec();
|
||||
Self { coeffs }
|
||||
}
|
||||
|
||||
/// Degree of the polynomial + 1, or 0 for a polynomial with no non-zero coefficients.
|
||||
pub(crate) fn degree_plus_one(&self) -> usize {
|
||||
(0usize..self.len())
|
||||
.rev()
|
||||
.find(|&i| self.coeffs[i].is_nonzero())
|
||||
.map_or(0, |i| i + 1)
|
||||
}
|
||||
|
||||
/// Leading coefficient.
|
||||
pub fn lead(&self) -> F {
|
||||
self.coeffs
|
||||
.iter()
|
||||
.rev()
|
||||
.find(|x| x.is_nonzero())
|
||||
.map_or(F::ZERO, |x| *x)
|
||||
}
|
||||
|
||||
/// Reverse the order of the coefficients, not taking into account the leading zero coefficients.
|
||||
pub(crate) fn rev(&self) -> Self {
|
||||
Self::new(self.trimmed().coeffs.into_iter().rev().collect())
|
||||
}
|
||||
|
||||
pub fn fft(&self) -> PolynomialValues<F> {
|
||||
fft(self)
|
||||
}
|
||||
|
||||
pub fn fft_with_options(
|
||||
&self,
|
||||
zero_factor: Option<usize>,
|
||||
root_table: Option<&FftRootTable<F>>,
|
||||
) -> PolynomialValues<F> {
|
||||
fft_with_options(self, zero_factor, root_table)
|
||||
}
|
||||
|
||||
/// Returns the evaluation of the polynomial on the coset `shift*H`.
|
||||
pub fn coset_fft(&self, shift: F) -> PolynomialValues<F> {
|
||||
self.coset_fft_with_options(shift, None, None)
|
||||
}
|
||||
|
||||
/// Returns the evaluation of the polynomial on the coset `shift*H`.
|
||||
pub fn coset_fft_with_options(
|
||||
&self,
|
||||
shift: F,
|
||||
zero_factor: Option<usize>,
|
||||
root_table: Option<&FftRootTable<F>>,
|
||||
) -> PolynomialValues<F> {
|
||||
let modified_poly: Self = shift
|
||||
.powers()
|
||||
.zip(&self.coeffs)
|
||||
.map(|(r, &c)| r * c)
|
||||
.collect::<Vec<_>>()
|
||||
.into();
|
||||
modified_poly.fft_with_options(zero_factor, root_table)
|
||||
}
|
||||
|
||||
pub fn to_extension<const D: usize>(&self) -> PolynomialCoeffs<F::Extension>
|
||||
where
|
||||
F: Extendable<D>,
|
||||
{
|
||||
PolynomialCoeffs::new(self.coeffs.iter().map(|&c| c.into()).collect())
|
||||
}
|
||||
|
||||
pub fn mul_extension<const D: usize>(&self, rhs: F::Extension) -> PolynomialCoeffs<F::Extension>
|
||||
where
|
||||
F: Extendable<D>,
|
||||
{
|
||||
PolynomialCoeffs::new(self.coeffs.iter().map(|&c| rhs.scalar_mul(c)).collect())
|
||||
}
|
||||
}
|
||||
|
||||
impl<F: Field> PartialEq for PolynomialCoeffs<F> {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
let max_terms = self.coeffs.len().max(other.coeffs.len());
|
||||
for i in 0..max_terms {
|
||||
let self_i = self.coeffs.get(i).cloned().unwrap_or(F::ZERO);
|
||||
let other_i = other.coeffs.get(i).cloned().unwrap_or(F::ZERO);
|
||||
if self_i != other_i {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
true
|
||||
}
|
||||
}
|
||||
|
||||
impl<F: Field> Eq for PolynomialCoeffs<F> {}
|
||||
|
||||
impl<F: Field> From<Vec<F>> for PolynomialCoeffs<F> {
|
||||
fn from(coeffs: Vec<F>) -> Self {
|
||||
Self::new(coeffs)
|
||||
}
|
||||
}
|
||||
|
||||
impl<F: Field> Add for &PolynomialCoeffs<F> {
|
||||
type Output = PolynomialCoeffs<F>;
|
||||
|
||||
fn add(self, rhs: Self) -> Self::Output {
|
||||
let len = max(self.len(), rhs.len());
|
||||
let a = self.padded(len).coeffs;
|
||||
let b = rhs.padded(len).coeffs;
|
||||
let coeffs = a.into_iter().zip(b).map(|(x, y)| x + y).collect();
|
||||
PolynomialCoeffs::new(coeffs)
|
||||
}
|
||||
}
|
||||
|
||||
impl<F: Field> Sum for PolynomialCoeffs<F> {
|
||||
fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
|
||||
iter.fold(Self::empty(), |acc, p| &acc + &p)
|
||||
}
|
||||
}
|
||||
|
||||
impl<F: Field> Sub for &PolynomialCoeffs<F> {
|
||||
type Output = PolynomialCoeffs<F>;
|
||||
|
||||
fn sub(self, rhs: Self) -> Self::Output {
|
||||
let len = max(self.len(), rhs.len());
|
||||
let mut coeffs = self.padded(len).coeffs;
|
||||
for (i, &c) in rhs.coeffs.iter().enumerate() {
|
||||
coeffs[i] -= c;
|
||||
}
|
||||
PolynomialCoeffs::new(coeffs)
|
||||
}
|
||||
}
|
||||
|
||||
impl<F: Field> AddAssign for PolynomialCoeffs<F> {
|
||||
fn add_assign(&mut self, rhs: Self) {
|
||||
let len = max(self.len(), rhs.len());
|
||||
self.coeffs.resize(len, F::ZERO);
|
||||
for (l, r) in self.coeffs.iter_mut().zip(rhs.coeffs) {
|
||||
*l += r;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<F: Field> AddAssign<&Self> for PolynomialCoeffs<F> {
|
||||
fn add_assign(&mut self, rhs: &Self) {
|
||||
let len = max(self.len(), rhs.len());
|
||||
self.coeffs.resize(len, F::ZERO);
|
||||
for (l, &r) in self.coeffs.iter_mut().zip(&rhs.coeffs) {
|
||||
*l += r;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<F: Field> SubAssign for PolynomialCoeffs<F> {
|
||||
fn sub_assign(&mut self, rhs: Self) {
|
||||
let len = max(self.len(), rhs.len());
|
||||
self.coeffs.resize(len, F::ZERO);
|
||||
for (l, r) in self.coeffs.iter_mut().zip(rhs.coeffs) {
|
||||
*l -= r;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<F: Field> SubAssign<&Self> for PolynomialCoeffs<F> {
|
||||
fn sub_assign(&mut self, rhs: &Self) {
|
||||
let len = max(self.len(), rhs.len());
|
||||
self.coeffs.resize(len, F::ZERO);
|
||||
for (l, &r) in self.coeffs.iter_mut().zip(&rhs.coeffs) {
|
||||
*l -= r;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<F: Field> Mul<F> for &PolynomialCoeffs<F> {
|
||||
type Output = PolynomialCoeffs<F>;
|
||||
|
||||
fn mul(self, rhs: F) -> Self::Output {
|
||||
let coeffs = self.coeffs.iter().map(|&x| rhs * x).collect();
|
||||
PolynomialCoeffs::new(coeffs)
|
||||
}
|
||||
}
|
||||
|
||||
impl<F: Field> MulAssign<F> for PolynomialCoeffs<F> {
|
||||
fn mul_assign(&mut self, rhs: F) {
|
||||
self.coeffs.iter_mut().for_each(|x| *x *= rhs);
|
||||
}
|
||||
}
|
||||
|
||||
impl<F: Field> Mul for &PolynomialCoeffs<F> {
|
||||
type Output = PolynomialCoeffs<F>;
|
||||
|
||||
#[allow(clippy::suspicious_arithmetic_impl)]
|
||||
fn mul(self, rhs: Self) -> Self::Output {
|
||||
let new_len = (self.len() + rhs.len()).next_power_of_two();
|
||||
let a = self.padded(new_len);
|
||||
let b = rhs.padded(new_len);
|
||||
let a_evals = a.fft();
|
||||
let b_evals = b.fft();
|
||||
|
||||
let mul_evals: Vec<F> = a_evals
|
||||
.values
|
||||
.into_iter()
|
||||
.zip(b_evals.values)
|
||||
.map(|(pa, pb)| pa * pb)
|
||||
.collect();
|
||||
ifft(&mul_evals.into())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::time::Instant;
|
||||
|
||||
use rand::{thread_rng, Rng};
|
||||
|
||||
use super::*;
|
||||
use crate::field::goldilocks_field::GoldilocksField;
|
||||
|
||||
#[test]
|
||||
fn test_trimmed() {
|
||||
type F = GoldilocksField;
|
||||
|
||||
assert_eq!(
|
||||
PolynomialCoeffs::<F> { coeffs: vec![] }.trimmed(),
|
||||
PolynomialCoeffs::<F> { coeffs: vec![] }
|
||||
);
|
||||
assert_eq!(
|
||||
PolynomialCoeffs::<F> {
|
||||
coeffs: vec![F::ZERO]
|
||||
}
|
||||
.trimmed(),
|
||||
PolynomialCoeffs::<F> { coeffs: vec![] }
|
||||
);
|
||||
assert_eq!(
|
||||
PolynomialCoeffs::<F> {
|
||||
coeffs: vec![F::ONE, F::TWO, F::ZERO, F::ZERO]
|
||||
}
|
||||
.trimmed(),
|
||||
PolynomialCoeffs::<F> {
|
||||
coeffs: vec![F::ONE, F::TWO]
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_coset_fft() {
|
||||
type F = GoldilocksField;
|
||||
|
||||
let k = 8;
|
||||
let n = 1 << k;
|
||||
let poly = PolynomialCoeffs::new(F::rand_vec(n));
|
||||
let shift = F::rand();
|
||||
let coset_evals = poly.coset_fft(shift).values;
|
||||
|
||||
let generator = F::primitive_root_of_unity(k);
|
||||
let naive_coset_evals = F::cyclic_subgroup_coset_known_order(generator, shift, n)
|
||||
.into_iter()
|
||||
.map(|x| poly.eval(x))
|
||||
.collect::<Vec<_>>();
|
||||
assert_eq!(coset_evals, naive_coset_evals);
|
||||
|
||||
let ifft_coeffs = PolynomialValues::new(coset_evals).coset_ifft(shift);
|
||||
assert_eq!(poly, ifft_coeffs);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_coset_ifft() {
|
||||
type F = GoldilocksField;
|
||||
|
||||
let k = 8;
|
||||
let n = 1 << k;
|
||||
let evals = PolynomialValues::new(F::rand_vec(n));
|
||||
let shift = F::rand();
|
||||
let coeffs = evals.coset_ifft(shift);
|
||||
|
||||
let generator = F::primitive_root_of_unity(k);
|
||||
let naive_coset_evals = F::cyclic_subgroup_coset_known_order(generator, shift, n)
|
||||
.into_iter()
|
||||
.map(|x| coeffs.eval(x))
|
||||
.collect::<Vec<_>>();
|
||||
assert_eq!(evals, naive_coset_evals.into());
|
||||
|
||||
let fft_evals = coeffs.coset_fft(shift);
|
||||
assert_eq!(evals, fft_evals);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_polynomial_multiplication() {
|
||||
type F = GoldilocksField;
|
||||
let mut rng = thread_rng();
|
||||
let (a_deg, b_deg) = (rng.gen_range(1..10_000), rng.gen_range(1..10_000));
|
||||
let a = PolynomialCoeffs::new(F::rand_vec(a_deg));
|
||||
let b = PolynomialCoeffs::new(F::rand_vec(b_deg));
|
||||
let m1 = &a * &b;
|
||||
let m2 = &a * &b;
|
||||
for _ in 0..1000 {
|
||||
let x = F::rand();
|
||||
assert_eq!(m1.eval(x), a.eval(x) * b.eval(x));
|
||||
assert_eq!(m2.eval(x), a.eval(x) * b.eval(x));
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_inv_mod_xn() {
|
||||
type F = GoldilocksField;
|
||||
let mut rng = thread_rng();
|
||||
let a_deg = rng.gen_range(1..1_000);
|
||||
let n = rng.gen_range(1..1_000);
|
||||
let a = PolynomialCoeffs::new(F::rand_vec(a_deg));
|
||||
let b = a.inv_mod_xn(n);
|
||||
let mut m = &a * &b;
|
||||
m.coeffs.drain(n..);
|
||||
m.trim();
|
||||
assert_eq!(
|
||||
m,
|
||||
PolynomialCoeffs::new(vec![F::ONE]),
|
||||
"a: {:#?}, b:{:#?}, n:{:#?}, m:{:#?}",
|
||||
a,
|
||||
b,
|
||||
n,
|
||||
m
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_polynomial_long_division() {
|
||||
type F = GoldilocksField;
|
||||
let mut rng = thread_rng();
|
||||
let (a_deg, b_deg) = (rng.gen_range(1..10_000), rng.gen_range(1..10_000));
|
||||
let a = PolynomialCoeffs::new(F::rand_vec(a_deg));
|
||||
let b = PolynomialCoeffs::new(F::rand_vec(b_deg));
|
||||
let (q, r) = a.div_rem_long_division(&b);
|
||||
for _ in 0..1000 {
|
||||
let x = F::rand();
|
||||
assert_eq!(a.eval(x), b.eval(x) * q.eval(x) + r.eval(x));
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_polynomial_division() {
|
||||
type F = GoldilocksField;
|
||||
let mut rng = thread_rng();
|
||||
let (a_deg, b_deg) = (rng.gen_range(1..10_000), rng.gen_range(1..10_000));
|
||||
let a = PolynomialCoeffs::new(F::rand_vec(a_deg));
|
||||
let b = PolynomialCoeffs::new(F::rand_vec(b_deg));
|
||||
let (q, r) = a.div_rem(&b);
|
||||
for _ in 0..1000 {
|
||||
let x = F::rand();
|
||||
assert_eq!(a.eval(x), b.eval(x) * q.eval(x) + r.eval(x));
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_polynomial_division_by_constant() {
|
||||
type F = GoldilocksField;
|
||||
let mut rng = thread_rng();
|
||||
let a_deg = rng.gen_range(1..10_000);
|
||||
let a = PolynomialCoeffs::new(F::rand_vec(a_deg));
|
||||
let b = PolynomialCoeffs::from(vec![F::rand()]);
|
||||
let (q, r) = a.div_rem(&b);
|
||||
for _ in 0..1000 {
|
||||
let x = F::rand();
|
||||
assert_eq!(a.eval(x), b.eval(x) * q.eval(x) + r.eval(x));
|
||||
}
|
||||
}
|
||||
|
||||
// Test to see which polynomial division method is faster for divisions of the type
|
||||
// `(X^n - 1)/(X - a)
|
||||
#[test]
|
||||
fn test_division_linear() {
|
||||
type F = GoldilocksField;
|
||||
let mut rng = thread_rng();
|
||||
let l = 14;
|
||||
let n = 1 << l;
|
||||
let g = F::primitive_root_of_unity(l);
|
||||
let xn_minus_one = {
|
||||
let mut xn_min_one_vec = vec![F::ZERO; n + 1];
|
||||
xn_min_one_vec[n] = F::ONE;
|
||||
xn_min_one_vec[0] = F::NEG_ONE;
|
||||
PolynomialCoeffs::new(xn_min_one_vec)
|
||||
};
|
||||
|
||||
let a = g.exp_u64(rng.gen_range(0..(n as u64)));
|
||||
let denom = PolynomialCoeffs::new(vec![-a, F::ONE]);
|
||||
let now = Instant::now();
|
||||
xn_minus_one.div_rem(&denom);
|
||||
println!("Division time: {:?}", now.elapsed());
|
||||
let now = Instant::now();
|
||||
xn_minus_one.div_rem_long_division(&denom);
|
||||
println!("Division time: {:?}", now.elapsed());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn eq() {
|
||||
type F = GoldilocksField;
|
||||
assert_eq!(
|
||||
PolynomialCoeffs::<F>::new(vec![]),
|
||||
PolynomialCoeffs::new(vec![])
|
||||
);
|
||||
assert_eq!(
|
||||
PolynomialCoeffs::<F>::new(vec![F::ZERO]),
|
||||
PolynomialCoeffs::new(vec![F::ZERO])
|
||||
);
|
||||
assert_eq!(
|
||||
PolynomialCoeffs::<F>::new(vec![]),
|
||||
PolynomialCoeffs::new(vec![F::ZERO])
|
||||
);
|
||||
assert_eq!(
|
||||
PolynomialCoeffs::<F>::new(vec![F::ZERO]),
|
||||
PolynomialCoeffs::new(vec![])
|
||||
);
|
||||
assert_eq!(
|
||||
PolynomialCoeffs::<F>::new(vec![F::ZERO]),
|
||||
PolynomialCoeffs::new(vec![F::ZERO, F::ZERO])
|
||||
);
|
||||
assert_eq!(
|
||||
PolynomialCoeffs::<F>::new(vec![F::ONE]),
|
||||
PolynomialCoeffs::new(vec![F::ONE, F::ZERO])
|
||||
);
|
||||
assert_ne!(
|
||||
PolynomialCoeffs::<F>::new(vec![]),
|
||||
PolynomialCoeffs::new(vec![F::ONE])
|
||||
);
|
||||
assert_ne!(
|
||||
PolynomialCoeffs::<F>::new(vec![F::ZERO]),
|
||||
PolynomialCoeffs::new(vec![F::ZERO, F::ONE])
|
||||
);
|
||||
assert_ne!(
|
||||
PolynomialCoeffs::<F>::new(vec![F::ZERO]),
|
||||
PolynomialCoeffs::new(vec![F::ONE, F::ZERO])
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@ -1,635 +0,0 @@
|
||||
use std::cmp::max;
|
||||
use std::iter::Sum;
|
||||
use std::ops::{Add, AddAssign, Mul, MulAssign, Sub, SubAssign};
|
||||
|
||||
use anyhow::{ensure, Result};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::field::extension_field::{Extendable, FieldExtension};
|
||||
use crate::field::fft::{fft, fft_with_options, ifft, FftRootTable};
|
||||
use crate::field::field_types::Field;
|
||||
use crate::util::log2_strict;
|
||||
|
||||
/// A polynomial in point-value form.
|
||||
///
|
||||
/// The points are implicitly `g^i`, where `g` generates the subgroup whose size equals the number
|
||||
/// of points.
|
||||
#[derive(Clone, Debug, Eq, PartialEq)]
|
||||
pub struct PolynomialValues<F: Field> {
|
||||
pub values: Vec<F>,
|
||||
}
|
||||
|
||||
impl<F: Field> PolynomialValues<F> {
|
||||
pub fn new(values: Vec<F>) -> Self {
|
||||
PolynomialValues { values }
|
||||
}
|
||||
|
||||
pub(crate) fn zero(len: usize) -> Self {
|
||||
Self::new(vec![F::ZERO; len])
|
||||
}
|
||||
|
||||
/// The number of values stored.
|
||||
pub(crate) fn len(&self) -> usize {
|
||||
self.values.len()
|
||||
}
|
||||
|
||||
pub fn ifft(&self) -> PolynomialCoeffs<F> {
|
||||
ifft(self)
|
||||
}
|
||||
|
||||
/// Returns the polynomial whose evaluation on the coset `shift*H` is `self`.
|
||||
pub fn coset_ifft(&self, shift: F) -> PolynomialCoeffs<F> {
|
||||
let mut shifted_coeffs = self.ifft();
|
||||
shifted_coeffs
|
||||
.coeffs
|
||||
.iter_mut()
|
||||
.zip(shift.inverse().powers())
|
||||
.for_each(|(c, r)| {
|
||||
*c *= r;
|
||||
});
|
||||
shifted_coeffs
|
||||
}
|
||||
|
||||
pub fn lde_multiple(polys: Vec<Self>, rate_bits: usize) -> Vec<Self> {
|
||||
polys.into_iter().map(|p| p.lde(rate_bits)).collect()
|
||||
}
|
||||
|
||||
pub fn lde(&self, rate_bits: usize) -> Self {
|
||||
let coeffs = ifft(self).lde(rate_bits);
|
||||
fft_with_options(&coeffs, Some(rate_bits), None)
|
||||
}
|
||||
|
||||
pub fn degree(&self) -> usize {
|
||||
self.degree_plus_one()
|
||||
.checked_sub(1)
|
||||
.expect("deg(0) is undefined")
|
||||
}
|
||||
|
||||
pub fn degree_plus_one(&self) -> usize {
|
||||
self.ifft().degree_plus_one()
|
||||
}
|
||||
}
|
||||
|
||||
impl<F: Field> From<Vec<F>> for PolynomialValues<F> {
|
||||
fn from(values: Vec<F>) -> Self {
|
||||
Self::new(values)
|
||||
}
|
||||
}
|
||||
|
||||
/// A polynomial in coefficient form.
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
#[serde(bound = "")]
|
||||
pub struct PolynomialCoeffs<F: Field> {
|
||||
pub(crate) coeffs: Vec<F>,
|
||||
}
|
||||
|
||||
impl<F: Field> PolynomialCoeffs<F> {
|
||||
pub fn new(coeffs: Vec<F>) -> Self {
|
||||
PolynomialCoeffs { coeffs }
|
||||
}
|
||||
|
||||
/// Create a new polynomial with its coefficient list padded to the next power of two.
|
||||
pub(crate) fn new_padded(mut coeffs: Vec<F>) -> Self {
|
||||
while !coeffs.len().is_power_of_two() {
|
||||
coeffs.push(F::ZERO);
|
||||
}
|
||||
PolynomialCoeffs { coeffs }
|
||||
}
|
||||
|
||||
pub(crate) fn empty() -> Self {
|
||||
Self::new(Vec::new())
|
||||
}
|
||||
|
||||
pub(crate) fn zero(len: usize) -> Self {
|
||||
Self::new(vec![F::ZERO; len])
|
||||
}
|
||||
|
||||
pub(crate) fn one() -> Self {
|
||||
Self::new(vec![F::ONE])
|
||||
}
|
||||
|
||||
pub(crate) fn is_zero(&self) -> bool {
|
||||
self.coeffs.iter().all(|x| x.is_zero())
|
||||
}
|
||||
|
||||
/// The number of coefficients. This does not filter out any zero coefficients, so it is not
|
||||
/// necessarily related to the degree.
|
||||
pub fn len(&self) -> usize {
|
||||
self.coeffs.len()
|
||||
}
|
||||
|
||||
pub fn log_len(&self) -> usize {
|
||||
log2_strict(self.len())
|
||||
}
|
||||
|
||||
pub(crate) fn chunks(&self, chunk_size: usize) -> Vec<Self> {
|
||||
self.coeffs
|
||||
.chunks(chunk_size)
|
||||
.map(|chunk| PolynomialCoeffs::new(chunk.to_vec()))
|
||||
.collect()
|
||||
}
|
||||
|
||||
pub fn eval(&self, x: F) -> F {
|
||||
self.coeffs
|
||||
.iter()
|
||||
.rev()
|
||||
.fold(F::ZERO, |acc, &c| acc * x + c)
|
||||
}
|
||||
|
||||
pub fn eval_base<const D: usize>(&self, x: F::BaseField) -> F
|
||||
where
|
||||
F: FieldExtension<D>,
|
||||
{
|
||||
self.coeffs
|
||||
.iter()
|
||||
.rev()
|
||||
.fold(F::ZERO, |acc, &c| acc.scalar_mul(x) + c)
|
||||
}
|
||||
|
||||
pub fn lde_multiple(polys: Vec<&Self>, rate_bits: usize) -> Vec<Self> {
|
||||
polys.into_iter().map(|p| p.lde(rate_bits)).collect()
|
||||
}
|
||||
|
||||
pub fn lde(&self, rate_bits: usize) -> Self {
|
||||
self.padded(self.len() << rate_bits)
|
||||
}
|
||||
|
||||
pub(crate) fn pad(&mut self, new_len: usize) -> Result<()> {
|
||||
ensure!(
|
||||
new_len >= self.len(),
|
||||
"Trying to pad a polynomial of length {} to a length of {}.",
|
||||
self.len(),
|
||||
new_len
|
||||
);
|
||||
self.coeffs.resize(new_len, F::ZERO);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) fn padded(&self, new_len: usize) -> Self {
|
||||
let mut poly = self.clone();
|
||||
poly.pad(new_len).unwrap();
|
||||
poly
|
||||
}
|
||||
|
||||
/// Removes leading zero coefficients.
|
||||
pub fn trim(&mut self) {
|
||||
self.coeffs.truncate(self.degree_plus_one());
|
||||
}
|
||||
|
||||
/// Removes leading zero coefficients.
|
||||
pub fn trimmed(&self) -> Self {
|
||||
let coeffs = self.coeffs[..self.degree_plus_one()].to_vec();
|
||||
Self { coeffs }
|
||||
}
|
||||
|
||||
/// Degree of the polynomial + 1, or 0 for a polynomial with no non-zero coefficients.
|
||||
pub(crate) fn degree_plus_one(&self) -> usize {
|
||||
(0usize..self.len())
|
||||
.rev()
|
||||
.find(|&i| self.coeffs[i].is_nonzero())
|
||||
.map_or(0, |i| i + 1)
|
||||
}
|
||||
|
||||
/// Leading coefficient.
|
||||
pub fn lead(&self) -> F {
|
||||
self.coeffs
|
||||
.iter()
|
||||
.rev()
|
||||
.find(|x| x.is_nonzero())
|
||||
.map_or(F::ZERO, |x| *x)
|
||||
}
|
||||
|
||||
/// Reverse the order of the coefficients, not taking into account the leading zero coefficients.
|
||||
pub(crate) fn rev(&self) -> Self {
|
||||
Self::new(self.trimmed().coeffs.into_iter().rev().collect())
|
||||
}
|
||||
|
||||
pub fn fft(&self) -> PolynomialValues<F> {
|
||||
fft(self)
|
||||
}
|
||||
|
||||
pub fn fft_with_options(
|
||||
&self,
|
||||
zero_factor: Option<usize>,
|
||||
root_table: Option<&FftRootTable<F>>,
|
||||
) -> PolynomialValues<F> {
|
||||
fft_with_options(self, zero_factor, root_table)
|
||||
}
|
||||
|
||||
/// Returns the evaluation of the polynomial on the coset `shift*H`.
|
||||
pub fn coset_fft(&self, shift: F) -> PolynomialValues<F> {
|
||||
self.coset_fft_with_options(shift, None, None)
|
||||
}
|
||||
|
||||
/// Returns the evaluation of the polynomial on the coset `shift*H`.
|
||||
pub fn coset_fft_with_options(
|
||||
&self,
|
||||
shift: F,
|
||||
zero_factor: Option<usize>,
|
||||
root_table: Option<&FftRootTable<F>>,
|
||||
) -> PolynomialValues<F> {
|
||||
let modified_poly: Self = shift
|
||||
.powers()
|
||||
.zip(&self.coeffs)
|
||||
.map(|(r, &c)| r * c)
|
||||
.collect::<Vec<_>>()
|
||||
.into();
|
||||
modified_poly.fft_with_options(zero_factor, root_table)
|
||||
}
|
||||
|
||||
pub fn to_extension<const D: usize>(&self) -> PolynomialCoeffs<F::Extension>
|
||||
where
|
||||
F: Extendable<D>,
|
||||
{
|
||||
PolynomialCoeffs::new(self.coeffs.iter().map(|&c| c.into()).collect())
|
||||
}
|
||||
|
||||
pub fn mul_extension<const D: usize>(&self, rhs: F::Extension) -> PolynomialCoeffs<F::Extension>
|
||||
where
|
||||
F: Extendable<D>,
|
||||
{
|
||||
PolynomialCoeffs::new(self.coeffs.iter().map(|&c| rhs.scalar_mul(c)).collect())
|
||||
}
|
||||
}
|
||||
|
||||
impl<F: Field> PartialEq for PolynomialCoeffs<F> {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
let max_terms = self.coeffs.len().max(other.coeffs.len());
|
||||
for i in 0..max_terms {
|
||||
let self_i = self.coeffs.get(i).cloned().unwrap_or(F::ZERO);
|
||||
let other_i = other.coeffs.get(i).cloned().unwrap_or(F::ZERO);
|
||||
if self_i != other_i {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
true
|
||||
}
|
||||
}
|
||||
|
||||
impl<F: Field> Eq for PolynomialCoeffs<F> {}
|
||||
|
||||
impl<F: Field> From<Vec<F>> for PolynomialCoeffs<F> {
|
||||
fn from(coeffs: Vec<F>) -> Self {
|
||||
Self::new(coeffs)
|
||||
}
|
||||
}
|
||||
|
||||
impl<F: Field> Add for &PolynomialCoeffs<F> {
|
||||
type Output = PolynomialCoeffs<F>;
|
||||
|
||||
fn add(self, rhs: Self) -> Self::Output {
|
||||
let len = max(self.len(), rhs.len());
|
||||
let a = self.padded(len).coeffs;
|
||||
let b = rhs.padded(len).coeffs;
|
||||
let coeffs = a.into_iter().zip(b).map(|(x, y)| x + y).collect();
|
||||
PolynomialCoeffs::new(coeffs)
|
||||
}
|
||||
}
|
||||
|
||||
impl<F: Field> Sum for PolynomialCoeffs<F> {
|
||||
fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
|
||||
iter.fold(Self::empty(), |acc, p| &acc + &p)
|
||||
}
|
||||
}
|
||||
|
||||
impl<F: Field> Sub for &PolynomialCoeffs<F> {
|
||||
type Output = PolynomialCoeffs<F>;
|
||||
|
||||
fn sub(self, rhs: Self) -> Self::Output {
|
||||
let len = max(self.len(), rhs.len());
|
||||
let mut coeffs = self.padded(len).coeffs;
|
||||
for (i, &c) in rhs.coeffs.iter().enumerate() {
|
||||
coeffs[i] -= c;
|
||||
}
|
||||
PolynomialCoeffs::new(coeffs)
|
||||
}
|
||||
}
|
||||
|
||||
impl<F: Field> AddAssign for PolynomialCoeffs<F> {
|
||||
fn add_assign(&mut self, rhs: Self) {
|
||||
let len = max(self.len(), rhs.len());
|
||||
self.coeffs.resize(len, F::ZERO);
|
||||
for (l, r) in self.coeffs.iter_mut().zip(rhs.coeffs) {
|
||||
*l += r;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<F: Field> AddAssign<&Self> for PolynomialCoeffs<F> {
|
||||
fn add_assign(&mut self, rhs: &Self) {
|
||||
let len = max(self.len(), rhs.len());
|
||||
self.coeffs.resize(len, F::ZERO);
|
||||
for (l, &r) in self.coeffs.iter_mut().zip(&rhs.coeffs) {
|
||||
*l += r;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<F: Field> SubAssign for PolynomialCoeffs<F> {
|
||||
fn sub_assign(&mut self, rhs: Self) {
|
||||
let len = max(self.len(), rhs.len());
|
||||
self.coeffs.resize(len, F::ZERO);
|
||||
for (l, r) in self.coeffs.iter_mut().zip(rhs.coeffs) {
|
||||
*l -= r;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<F: Field> SubAssign<&Self> for PolynomialCoeffs<F> {
|
||||
fn sub_assign(&mut self, rhs: &Self) {
|
||||
let len = max(self.len(), rhs.len());
|
||||
self.coeffs.resize(len, F::ZERO);
|
||||
for (l, &r) in self.coeffs.iter_mut().zip(&rhs.coeffs) {
|
||||
*l -= r;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<F: Field> Mul<F> for &PolynomialCoeffs<F> {
|
||||
type Output = PolynomialCoeffs<F>;
|
||||
|
||||
fn mul(self, rhs: F) -> Self::Output {
|
||||
let coeffs = self.coeffs.iter().map(|&x| rhs * x).collect();
|
||||
PolynomialCoeffs::new(coeffs)
|
||||
}
|
||||
}
|
||||
|
||||
impl<F: Field> MulAssign<F> for PolynomialCoeffs<F> {
|
||||
fn mul_assign(&mut self, rhs: F) {
|
||||
self.coeffs.iter_mut().for_each(|x| *x *= rhs);
|
||||
}
|
||||
}
|
||||
|
||||
impl<F: Field> Mul for &PolynomialCoeffs<F> {
|
||||
type Output = PolynomialCoeffs<F>;
|
||||
|
||||
#[allow(clippy::suspicious_arithmetic_impl)]
|
||||
fn mul(self, rhs: Self) -> Self::Output {
|
||||
let new_len = (self.len() + rhs.len()).next_power_of_two();
|
||||
let a = self.padded(new_len);
|
||||
let b = rhs.padded(new_len);
|
||||
let a_evals = a.fft();
|
||||
let b_evals = b.fft();
|
||||
|
||||
let mul_evals: Vec<F> = a_evals
|
||||
.values
|
||||
.into_iter()
|
||||
.zip(b_evals.values)
|
||||
.map(|(pa, pb)| pa * pb)
|
||||
.collect();
|
||||
ifft(&mul_evals.into())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::time::Instant;
|
||||
|
||||
use rand::{thread_rng, Rng};
|
||||
|
||||
use super::*;
|
||||
use crate::field::goldilocks_field::GoldilocksField;
|
||||
|
||||
#[test]
|
||||
fn test_trimmed() {
|
||||
type F = GoldilocksField;
|
||||
|
||||
assert_eq!(
|
||||
PolynomialCoeffs::<F> { coeffs: vec![] }.trimmed(),
|
||||
PolynomialCoeffs::<F> { coeffs: vec![] }
|
||||
);
|
||||
assert_eq!(
|
||||
PolynomialCoeffs::<F> {
|
||||
coeffs: vec![F::ZERO]
|
||||
}
|
||||
.trimmed(),
|
||||
PolynomialCoeffs::<F> { coeffs: vec![] }
|
||||
);
|
||||
assert_eq!(
|
||||
PolynomialCoeffs::<F> {
|
||||
coeffs: vec![F::ONE, F::TWO, F::ZERO, F::ZERO]
|
||||
}
|
||||
.trimmed(),
|
||||
PolynomialCoeffs::<F> {
|
||||
coeffs: vec![F::ONE, F::TWO]
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_coset_fft() {
|
||||
type F = GoldilocksField;
|
||||
|
||||
let k = 8;
|
||||
let n = 1 << k;
|
||||
let poly = PolynomialCoeffs::new(F::rand_vec(n));
|
||||
let shift = F::rand();
|
||||
let coset_evals = poly.coset_fft(shift).values;
|
||||
|
||||
let generator = F::primitive_root_of_unity(k);
|
||||
let naive_coset_evals = F::cyclic_subgroup_coset_known_order(generator, shift, n)
|
||||
.into_iter()
|
||||
.map(|x| poly.eval(x))
|
||||
.collect::<Vec<_>>();
|
||||
assert_eq!(coset_evals, naive_coset_evals);
|
||||
|
||||
let ifft_coeffs = PolynomialValues::new(coset_evals).coset_ifft(shift);
|
||||
assert_eq!(poly, ifft_coeffs.into());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_coset_ifft() {
|
||||
type F = GoldilocksField;
|
||||
|
||||
let k = 8;
|
||||
let n = 1 << k;
|
||||
let evals = PolynomialValues::new(F::rand_vec(n));
|
||||
let shift = F::rand();
|
||||
let coeffs = evals.coset_ifft(shift);
|
||||
|
||||
let generator = F::primitive_root_of_unity(k);
|
||||
let naive_coset_evals = F::cyclic_subgroup_coset_known_order(generator, shift, n)
|
||||
.into_iter()
|
||||
.map(|x| coeffs.eval(x))
|
||||
.collect::<Vec<_>>();
|
||||
assert_eq!(evals, naive_coset_evals.into());
|
||||
|
||||
let fft_evals = coeffs.coset_fft(shift);
|
||||
assert_eq!(evals, fft_evals);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_polynomial_multiplication() {
|
||||
type F = GoldilocksField;
|
||||
let mut rng = thread_rng();
|
||||
let (a_deg, b_deg) = (rng.gen_range(1..10_000), rng.gen_range(1..10_000));
|
||||
let a = PolynomialCoeffs::new(F::rand_vec(a_deg));
|
||||
let b = PolynomialCoeffs::new(F::rand_vec(b_deg));
|
||||
let m1 = &a * &b;
|
||||
let m2 = &a * &b;
|
||||
for _ in 0..1000 {
|
||||
let x = F::rand();
|
||||
assert_eq!(m1.eval(x), a.eval(x) * b.eval(x));
|
||||
assert_eq!(m2.eval(x), a.eval(x) * b.eval(x));
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_inv_mod_xn() {
|
||||
type F = GoldilocksField;
|
||||
let mut rng = thread_rng();
|
||||
let a_deg = rng.gen_range(1..1_000);
|
||||
let n = rng.gen_range(1..1_000);
|
||||
let a = PolynomialCoeffs::new(F::rand_vec(a_deg));
|
||||
let b = a.inv_mod_xn(n);
|
||||
let mut m = &a * &b;
|
||||
m.coeffs.drain(n..);
|
||||
m.trim();
|
||||
assert_eq!(
|
||||
m,
|
||||
PolynomialCoeffs::new(vec![F::ONE]),
|
||||
"a: {:#?}, b:{:#?}, n:{:#?}, m:{:#?}",
|
||||
a,
|
||||
b,
|
||||
n,
|
||||
m
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_polynomial_long_division() {
|
||||
type F = GoldilocksField;
|
||||
let mut rng = thread_rng();
|
||||
let (a_deg, b_deg) = (rng.gen_range(1..10_000), rng.gen_range(1..10_000));
|
||||
let a = PolynomialCoeffs::new(F::rand_vec(a_deg));
|
||||
let b = PolynomialCoeffs::new(F::rand_vec(b_deg));
|
||||
let (q, r) = a.div_rem_long_division(&b);
|
||||
for _ in 0..1000 {
|
||||
let x = F::rand();
|
||||
assert_eq!(a.eval(x), b.eval(x) * q.eval(x) + r.eval(x));
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_polynomial_division() {
|
||||
type F = GoldilocksField;
|
||||
let mut rng = thread_rng();
|
||||
let (a_deg, b_deg) = (rng.gen_range(1..10_000), rng.gen_range(1..10_000));
|
||||
let a = PolynomialCoeffs::new(F::rand_vec(a_deg));
|
||||
let b = PolynomialCoeffs::new(F::rand_vec(b_deg));
|
||||
let (q, r) = a.div_rem(&b);
|
||||
for _ in 0..1000 {
|
||||
let x = F::rand();
|
||||
assert_eq!(a.eval(x), b.eval(x) * q.eval(x) + r.eval(x));
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_polynomial_division_by_constant() {
|
||||
type F = GoldilocksField;
|
||||
let mut rng = thread_rng();
|
||||
let a_deg = rng.gen_range(1..10_000);
|
||||
let a = PolynomialCoeffs::new(F::rand_vec(a_deg));
|
||||
let b = PolynomialCoeffs::from(vec![F::rand()]);
|
||||
let (q, r) = a.div_rem(&b);
|
||||
for _ in 0..1000 {
|
||||
let x = F::rand();
|
||||
assert_eq!(a.eval(x), b.eval(x) * q.eval(x) + r.eval(x));
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_division_by_z_h() {
|
||||
type F = GoldilocksField;
|
||||
let mut rng = thread_rng();
|
||||
let a_deg = rng.gen_range(1..10_000);
|
||||
let n = rng.gen_range(1..a_deg);
|
||||
let mut a = PolynomialCoeffs::new(F::rand_vec(a_deg));
|
||||
a.trim();
|
||||
let z_h = {
|
||||
let mut z_h_vec = vec![F::ZERO; n + 1];
|
||||
z_h_vec[n] = F::ONE;
|
||||
z_h_vec[0] = F::NEG_ONE;
|
||||
PolynomialCoeffs::new(z_h_vec)
|
||||
};
|
||||
let m = &a * &z_h;
|
||||
let now = Instant::now();
|
||||
let mut a_test = m.divide_by_z_h(n);
|
||||
a_test.trim();
|
||||
println!("Division time: {:?}", now.elapsed());
|
||||
assert_eq!(a, a_test);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn divide_zero_poly_by_z_h() {
|
||||
let zero_poly = PolynomialCoeffs::<GoldilocksField>::empty();
|
||||
zero_poly.divide_by_z_h(16);
|
||||
}
|
||||
|
||||
// Test to see which polynomial division method is faster for divisions of the type
|
||||
// `(X^n - 1)/(X - a)
|
||||
#[test]
|
||||
fn test_division_linear() {
|
||||
type F = GoldilocksField;
|
||||
let mut rng = thread_rng();
|
||||
let l = 14;
|
||||
let n = 1 << l;
|
||||
let g = F::primitive_root_of_unity(l);
|
||||
let xn_minus_one = {
|
||||
let mut xn_min_one_vec = vec![F::ZERO; n + 1];
|
||||
xn_min_one_vec[n] = F::ONE;
|
||||
xn_min_one_vec[0] = F::NEG_ONE;
|
||||
PolynomialCoeffs::new(xn_min_one_vec)
|
||||
};
|
||||
|
||||
let a = g.exp_u64(rng.gen_range(0..(n as u64)));
|
||||
let denom = PolynomialCoeffs::new(vec![-a, F::ONE]);
|
||||
let now = Instant::now();
|
||||
xn_minus_one.div_rem(&denom);
|
||||
println!("Division time: {:?}", now.elapsed());
|
||||
let now = Instant::now();
|
||||
xn_minus_one.div_rem_long_division(&denom);
|
||||
println!("Division time: {:?}", now.elapsed());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn eq() {
|
||||
type F = GoldilocksField;
|
||||
assert_eq!(
|
||||
PolynomialCoeffs::<F>::new(vec![]),
|
||||
PolynomialCoeffs::new(vec![])
|
||||
);
|
||||
assert_eq!(
|
||||
PolynomialCoeffs::<F>::new(vec![F::ZERO]),
|
||||
PolynomialCoeffs::new(vec![F::ZERO])
|
||||
);
|
||||
assert_eq!(
|
||||
PolynomialCoeffs::<F>::new(vec![]),
|
||||
PolynomialCoeffs::new(vec![F::ZERO])
|
||||
);
|
||||
assert_eq!(
|
||||
PolynomialCoeffs::<F>::new(vec![F::ZERO]),
|
||||
PolynomialCoeffs::new(vec![])
|
||||
);
|
||||
assert_eq!(
|
||||
PolynomialCoeffs::<F>::new(vec![F::ZERO]),
|
||||
PolynomialCoeffs::new(vec![F::ZERO, F::ZERO])
|
||||
);
|
||||
assert_eq!(
|
||||
PolynomialCoeffs::<F>::new(vec![F::ONE]),
|
||||
PolynomialCoeffs::new(vec![F::ONE, F::ZERO])
|
||||
);
|
||||
assert_ne!(
|
||||
PolynomialCoeffs::<F>::new(vec![]),
|
||||
PolynomialCoeffs::new(vec![F::ONE])
|
||||
);
|
||||
assert_ne!(
|
||||
PolynomialCoeffs::<F>::new(vec![F::ZERO]),
|
||||
PolynomialCoeffs::new(vec![F::ZERO, F::ONE])
|
||||
);
|
||||
assert_ne!(
|
||||
PolynomialCoeffs::<F>::new(vec![F::ZERO]),
|
||||
PolynomialCoeffs::new(vec![F::ONE, F::ZERO])
|
||||
);
|
||||
}
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
x
Reference in New Issue
Block a user