diff --git a/src/circuit_builder.rs b/src/circuit_builder.rs index 2d00b403..fd786dfc 100644 --- a/src/circuit_builder.rs +++ b/src/circuit_builder.rs @@ -310,16 +310,16 @@ impl, const D: usize> CircuitBuilder { input: w, }), }); - self.add_generator(CopyGenerator { - src: Target::Wire(Wire { + self.generate_copy( + Target::Wire(Wire { gate: gate_1, input: w, }), - dst: Target::Wire(Wire { + Target::Wire(Wire { gate: gate_2, input: w, }), - }); + ); } } diff --git a/src/field/fft.rs b/src/field/fft.rs index af5c05a7..fa65a5ea 100644 --- a/src/field/fft.rs +++ b/src/field/fft.rs @@ -6,7 +6,10 @@ use crate::util::{log2_strict, reverse_index_bits}; // TODO: Should really do some "dynamic" dispatch to handle the // different FFT algos rather than C-style enum dispatch. -enum FftStrategy { Classic, Unrolled } +enum FftStrategy { + Classic, + Unrolled, +} const FFT_STRATEGY: FftStrategy = FftStrategy::Classic; @@ -33,7 +36,6 @@ fn fft_classic_root_table(n: usize) -> FftRootTable { root_table } - fn fft_unrolled_root_table(n: usize) -> FftRootTable { // Precompute a table of the roots of unity used in the main // loops. @@ -67,18 +69,20 @@ fn fft_unrolled_root_table(n: usize) -> FftRootTable { fn fft_dispatch( input: Vec, zero_factor: Option, - root_table: Option> + root_table: Option>, ) -> Vec { let n = input.len(); match FFT_STRATEGY { - FftStrategy::Classic - => fft_classic(input, - zero_factor.unwrap_or(0), - root_table.unwrap_or_else(|| fft_classic_root_table(n))), - FftStrategy::Unrolled - => fft_unrolled(input, - zero_factor.unwrap_or(0), - root_table.unwrap_or_else(|| fft_unrolled_root_table(n))) + FftStrategy::Classic => fft_classic( + input, + zero_factor.unwrap_or(0), + root_table.unwrap_or_else(|| fft_classic_root_table(n)), + ), + FftStrategy::Unrolled => fft_unrolled( + input, + zero_factor.unwrap_or(0), + root_table.unwrap_or_else(|| fft_unrolled_root_table(n)), + ), } } @@ -91,10 +95,12 @@ pub fn fft(poly: PolynomialCoeffs) -> PolynomialValues { pub fn fft_with_options( poly: PolynomialCoeffs, zero_factor: Option, - root_table: Option> + root_table: Option>, ) -> PolynomialValues { let PolynomialCoeffs { coeffs } = poly; - PolynomialValues { values: fft_dispatch(coeffs, zero_factor, root_table) } + PolynomialValues { + values: fft_dispatch(coeffs, zero_factor, root_table), + } } #[inline] @@ -105,7 +111,7 @@ pub fn ifft(poly: PolynomialValues) -> PolynomialCoeffs { pub fn ifft_with_options( poly: PolynomialValues, zero_factor: Option, - root_table: Option> + root_table: Option>, ) -> PolynomialCoeffs { let n = poly.len(); let lg_n = log2_strict(n); @@ -136,7 +142,7 @@ pub fn ifft_with_options( pub(crate) fn fft_classic( input: Vec, r: usize, - root_table: FftRootTable + root_table: FftRootTable, ) -> Vec { let mut values = reverse_index_bits(input); @@ -144,7 +150,11 @@ pub(crate) fn fft_classic( let lg_n = log2_strict(n); if root_table.len() != lg_n { - panic!("Expected root table of length {}, but it was {}.", lg_n, root_table.len()); + panic!( + "Expected root table of length {}, but it was {}.", + lg_n, + root_table.len() + ); } // After reverse_index_bits, the only non-zero elements of values @@ -154,7 +164,8 @@ pub(crate) fn fft_classic( // element i*2^r with the value at i*2^r. This corresponds to the // first r rounds of the FFT when there are 2^r zeros at the end // of the original input. - if r > 0 { // if r == 0 then this loop is a noop. + if r > 0 { + // if r == 0 then this loop is a noop. let mask = !((1 << r) - 1); for i in 0..n { values[i] = values[i & mask]; @@ -162,7 +173,7 @@ pub(crate) fn fft_classic( } let mut m = 1 << (r + 1); - for lg_m in (r+1)..=lg_n { + for lg_m in (r + 1)..=lg_n { let half_m = m / 2; for k in (0..n).step_by(m) { for j in 0..half_m { @@ -185,11 +196,7 @@ pub(crate) fn fft_classic( /// The parameter r signifies that the first 1/2^r of the entries of /// input may be non-zero, but the last 1 - 1/2^r entries are /// definitely zero. -fn fft_unrolled( - input: Vec, - r_orig: usize, - root_table: FftRootTable -) -> Vec { +fn fft_unrolled(input: Vec, r_orig: usize, root_table: FftRootTable) -> Vec { let n = input.len(); let lg_n = log2_strict(input.len()); @@ -197,7 +204,7 @@ fn fft_unrolled( // FFT of a constant polynomial (including zero) is itself. if n < 2 { - return values + return values; } // The 'm' corresponds to the specialisation from the 'm' in the @@ -206,7 +213,8 @@ fn fft_unrolled( // (See comment in fft_classic near same code.) let mut r = r_orig; let mut m = 1 << r; - if r > 0 { // if r == 0 then this loop is a noop. + if r > 0 { + // if r == 0 then this loop is a noop. let mask = !((1 << r) - 1); for i in 0..n { values[i] = values[i & mask]; @@ -225,11 +233,15 @@ fn fft_unrolled( } if n == 2 { - return values + return values; } if root_table.len() != (lg_n - 1) { - panic!("Expected root table of length {}, but it was {}.", lg_n, root_table.len()); + panic!( + "Expected root table of length {}, but it was {}.", + lg_n, + root_table.len() + ); } // m = 2 @@ -253,7 +265,7 @@ fn fft_unrolled( // m >= 4 for lg_m in r..lg_n { - for k in (0..n).step_by(2*m) { + for k in (0..n).step_by(2 * m) { // Unrolled the commented loop by groups of 4 and // rearranged the lines. Improves runtime by about // 10%. @@ -294,11 +306,10 @@ fn fft_unrolled( values } - #[cfg(test)] mod tests { use crate::field::crandall_field::CrandallField; - use crate::field::fft::{fft, ifft, fft_with_options}; + use crate::field::fft::{fft, fft_with_options, ifft}; use crate::field::field::Field; use crate::polynomial::polynomial::{PolynomialCoeffs, PolynomialValues}; use crate::util::{log2_ceil, log2_strict}; @@ -328,7 +339,10 @@ mod tests { for r in 0..4 { // expand ceofficients by factor 2^r by filling with zeros let zero_tail = coefficients.clone().lde(r); - assert_eq!(fft(zero_tail.clone()), fft_with_options(zero_tail, Some(r), None)); + assert_eq!( + fft(zero_tail.clone()), + fft_with_options(zero_tail, Some(r), None) + ); } } diff --git a/src/field/field_testing.rs b/src/field/field_testing.rs index 7190684f..1f5bff6f 100644 --- a/src/field/field_testing.rs +++ b/src/field/field_testing.rs @@ -323,7 +323,7 @@ macro_rules! test_arithmetic { let v = ::PrimeField::TWO_ADICITY; - for e in [0, 1, 2, 3, 4, v - 2, v - 1, v, v + 1, v + 2, 123*v] { + for e in [0, 1, 2, 3, 4, v - 2, v - 1, v, v + 1, v + 2, 123 * v] { let x = F::TWO.exp(e as u64).inverse(); let y = F::inverse_2exp(e); assert_eq!(x, y); diff --git a/src/gadgets/insert.rs b/src/gadgets/insert.rs index 64cf7299..1f69cb24 100644 --- a/src/gadgets/insert.rs +++ b/src/gadgets/insert.rs @@ -1,9 +1,42 @@ use crate::circuit_builder::CircuitBuilder; use crate::field::extension_field::target::ExtensionTarget; use crate::field::extension_field::Extendable; +use crate::generator::NonzeroTestGenerator; use crate::target::Target; impl, const D: usize> CircuitBuilder { + /// Evaluates to 0 if `x` equals zero, 1 otherwise. + /// From section 2 of https://github.com/mir-protocol/r1cs-workshop/blob/master/workshop.pdf, + /// based on an idea from https://eprint.iacr.org/2012/598.pdf. + pub fn is_nonzero(&mut self, x: Target) -> Target { + // Dummy variable. + let m = self.add_virtual_target(); + + // The prover sets this the dummy variable to 1/x if x != 0, or to an arbitrary value if + // x == 0. + self.add_generator(NonzeroTestGenerator { + to_test: x, + dummy: m, + }); + + // Evaluates to (0) * (0) = 0 if x == 0 and (x) * (1/x) = 1 otherwise. + let y = self.mul(x, m); + + // Enforce that (1 - y) * x == 0. + let prod = self.arithmetic(F::NEG_ONE, x, y, F::ONE, x); + self.assert_zero(prod); + + y + } + + /// Evaluates to 1 if `x` and `y` are equal, 0 otherwise. + pub fn is_equal(&mut self, x: Target, y: Target) -> Target { + let difference = self.sub(x, y); + let not_equal = self.is_nonzero(difference); + let one = self.one(); + self.sub(one, not_equal) + } + /// Inserts a `Target` in a vector at a non-deterministic index. This is done by rotating to the /// left, inserting at 0 and then rotating to the right. /// Note: `index` is not range-checked. @@ -13,9 +46,29 @@ impl, const D: usize> CircuitBuilder { element: ExtensionTarget, v: Vec>, ) -> Vec> { - let mut v = self.rotate_left(index, &v); - v.insert(0, element); - self.rotate_right(index, &v) + let mut already_inserted = self.zero(); + let mut new_list = Vec::new(); + + for i in 0..v.len() { + let one = self.one(); + + let cur_index = self.constant(F::from_canonical_usize(i)); + let insert_here = self.is_equal(cur_index, index); + + let mut new_item = self.zero_extension(); + new_item = self.scalar_mul_add_extension(insert_here, element, new_item); + if i > 0 { + new_item = self.scalar_mul_add_extension(already_inserted, v[i - 1], new_item); + } + already_inserted = self.add(already_inserted, insert_here); + + let not_already_inserted = self.sub(one, already_inserted); + new_item = self.scalar_mul_add_extension(not_already_inserted, v[i], new_item); + + new_list.push(new_item); + } + + new_list } } #[cfg(test)] diff --git a/src/generator.rs b/src/generator.rs index a2b35a53..a47c5267 100644 --- a/src/generator.rs +++ b/src/generator.rs @@ -130,3 +130,27 @@ impl SimpleGenerator for RandomValueGenerator { PartialWitness::singleton_target(self.target, random_value) } } + +/// A generator for testing if a value equals zero +pub(crate) struct NonzeroTestGenerator { + pub(crate) to_test: Target, + pub(crate) dummy: Target, +} + +impl SimpleGenerator for NonzeroTestGenerator { + fn dependencies(&self) -> Vec { + vec![self.to_test] + } + + fn run_once(&self, witness: &PartialWitness) -> PartialWitness { + let to_test_value = witness.get_target(self.to_test); + + let dummy_value = if to_test_value == F::ZERO { + F::ONE + } else { + to_test_value.inverse() + }; + + PartialWitness::singleton_target(self.dummy, dummy_value) + } +} diff --git a/src/polynomial/polynomial.rs b/src/polynomial/polynomial.rs index 81d07b8f..aa06d641 100644 --- a/src/polynomial/polynomial.rs +++ b/src/polynomial/polynomial.rs @@ -1,6 +1,7 @@ use std::cmp::max; use std::iter::Sum; use std::ops::{Add, AddAssign, Mul, MulAssign, Sub, SubAssign}; +use std::time::Instant; use anyhow::{ensure, Result}; diff --git a/src/target.rs b/src/target.rs index e765f7eb..52be8b5a 100644 --- a/src/target.rs +++ b/src/target.rs @@ -7,11 +7,15 @@ use crate::wire::Wire; #[derive(Copy, Clone, Eq, PartialEq, Hash, Debug)] pub enum Target { Wire(Wire), - PublicInput { index: usize }, + PublicInput { + index: usize, + }, /// A target that doesn't have any inherent location in the witness (but it can be copied to /// another target that does). This is useful for representing intermediate values in witness /// generation. - VirtualTarget { index: usize }, + VirtualTarget { + index: usize, + }, } impl Target { diff --git a/src/util/mod.rs b/src/util/mod.rs index 8fd60d53..ee3b8440 100644 --- a/src/util/mod.rs +++ b/src/util/mod.rs @@ -77,7 +77,9 @@ pub(crate) fn reverse_bits(n: usize, num_bits: usize) -> usize { // to plain '>>' is to accommodate the case n == num_bits == 0, // which would become `0 >> 64`. Rust thinks that any shift of 64 // bits causes overflow, even when the argument is zero. - n.reverse_bits().overflowing_shr(usize::BITS - num_bits as u32).0 + n.reverse_bits() + .overflowing_shr(usize::BITS - num_bits as u32) + .0 } #[cfg(test)]