diff --git a/src/circuit_builder.rs b/src/circuit_builder.rs index 2e12ce0f..1cccc68c 100644 --- a/src/circuit_builder.rs +++ b/src/circuit_builder.rs @@ -275,7 +275,7 @@ impl, const D: usize> CircuitBuilder { self.constant(F::TWO) } - /// Returns a routable target with a value of `ORDER - 1`. + /// Returns a routable target with a value of `order() - 1`. pub fn neg_one(&mut self) -> Target { self.constant(F::NEG_ONE) } diff --git a/src/field/cosets.rs b/src/field/cosets.rs index 13ef2e21..f2edd892 100644 --- a/src/field/cosets.rs +++ b/src/field/cosets.rs @@ -1,12 +1,14 @@ +use num::bigint::BigUint; + use crate::field::field::Field; /// Finds a set of shifts that result in unique cosets for the multiplicative subgroup of size /// `2^subgroup_bits`. pub(crate) fn get_unique_coset_shifts(subgroup_size: usize, num_shifts: usize) -> Vec { // From Lagrange's theorem. - let num_cosets = (F::ORDER - 1) / (subgroup_size as u64); + let num_cosets = (F::order() - 1u32) / (subgroup_size as u32); assert!( - num_shifts as u64 <= num_cosets, + BigUint::from(num_shifts) <= num_cosets, "The subgroup does not have enough distinct cosets" ); diff --git a/src/field/crandall_field.rs b/src/field/crandall_field.rs index 9d38ff1a..051b0fd0 100644 --- a/src/field/crandall_field.rs +++ b/src/field/crandall_field.rs @@ -4,7 +4,10 @@ use std::hash::{Hash, Hasher}; use std::iter::{Product, Sum}; use std::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssign}; +use itertools::Itertools; +use num::bigint::BigUint; use num::Integer; +use rand::Rng; use serde::{Deserialize, Serialize}; use crate::field::extension_field::quadratic::QuadraticCrandallField; @@ -12,6 +15,8 @@ use crate::field::extension_field::quartic::QuarticCrandallField; use crate::field::extension_field::{Extendable, Frobenius}; use crate::field::field::Field; +const FIELD_ORDER: u64 = 18446744071293632513; + /// EPSILON = 9 * 2**28 - 1 const EPSILON: u64 = 2415919103; @@ -142,15 +147,18 @@ impl Field for CrandallField { const ZERO: Self = Self(0); const ONE: Self = Self(1); const TWO: Self = Self(2); - const NEG_ONE: Self = Self(Self::ORDER - 1); + const NEG_ONE: Self = Self(FIELD_ORDER - 1); - const ORDER: u64 = 18446744071293632513; const TWO_ADICITY: usize = 28; - const CHARACTERISTIC: u64 = Self::ORDER; + const CHARACTERISTIC: u64 = FIELD_ORDER; const MULTIPLICATIVE_GROUP_GENERATOR: Self = Self(5); const POWER_OF_TWO_GENERATOR: Self = Self(10281950781551402419); + fn order() -> BigUint { + BigUint::from(FIELD_ORDER) + } + #[inline] fn square(&self) -> Self { *self * *self @@ -170,7 +178,7 @@ impl Field for CrandallField { // Based on Algorithm 16 of "Efficient Software-Implementation of Finite Fields with // Applications to Cryptography". - let p = Self::ORDER; + let p = FIELD_ORDER; let mut u = self.to_canonical_u64(); let mut v = p; let mut b = 1u64; @@ -228,8 +236,8 @@ impl Field for CrandallField { fn to_canonical_u64(&self) -> u64 { let mut c = self.0; // We only need one condition subtraction, since 2 * ORDER would not fit in a u64. - if c >= Self::ORDER { - c -= Self::ORDER; + if c >= FIELD_ORDER { + c -= FIELD_ORDER; } c } @@ -239,6 +247,14 @@ impl Field for CrandallField { Self(n) } + fn to_canonical_biguint(&self) -> BigUint { + BigUint::from(self.to_canonical_u64()) + } + + fn from_canonical_biguint(n: BigUint) -> Self { + Self(n.iter_u64_digits().next().unwrap_or(0)) + } + fn cube_root(&self) -> Self { let x0 = *self; let x1 = x0.square(); @@ -326,6 +342,10 @@ impl Field for CrandallField { } result } + + fn rand_from_rng(rng: &mut R) -> Self { + Self::from_canonical_u64(rng.gen_range(0, FIELD_ORDER)) + } } impl Neg for CrandallField { @@ -336,7 +356,7 @@ impl Neg for CrandallField { if self.is_zero() { Self::ZERO } else { - Self(Self::ORDER - self.to_canonical_u64()) + Self(FIELD_ORDER - self.to_canonical_u64()) } } } @@ -348,7 +368,7 @@ impl Add for CrandallField { #[allow(clippy::suspicious_arithmetic_impl)] fn add(self, rhs: Self) -> Self { let (sum, over) = self.0.overflowing_add(rhs.0); - Self(sum.overflowing_sub((over as u64) * Self::ORDER).0) + Self(sum.overflowing_sub((over as u64) * FIELD_ORDER).0) } } @@ -371,7 +391,7 @@ impl Sub for CrandallField { #[allow(clippy::suspicious_arithmetic_impl)] fn sub(self, rhs: Self) -> Self { let (diff, under) = self.0.overflowing_sub(rhs.to_canonical_u64()); - Self(diff.overflowing_add((under as u64) * Self::ORDER).0) + Self(diff.overflowing_add((under as u64) * FIELD_ORDER).0) } } @@ -452,7 +472,8 @@ impl Frobenius<1> for CrandallField {} #[cfg(test)] mod tests { - use crate::test_arithmetic; + use crate::{test_field_arithmetic, test_prime_field_arithmetic}; - test_arithmetic!(crate::field::crandall_field::CrandallField); + test_prime_field_arithmetic!(crate::field::crandall_field::CrandallField); + test_field_arithmetic!(crate::field::crandall_field::CrandallField); } diff --git a/src/field/extension_field/mod.rs b/src/field/extension_field/mod.rs index 2a176fe9..7d706237 100644 --- a/src/field/extension_field/mod.rs +++ b/src/field/extension_field/mod.rs @@ -34,8 +34,8 @@ pub trait Frobenius: OEF { return self.repeated_frobenius(count % D); } let arr = self.to_basefield_array(); - let k = (Self::BaseField::ORDER - 1) / (D as u64); - let z0 = Self::W.exp(k * count as u64); + let k = (Self::BaseField::order() - 1u32) / (D as u64); + let z0 = Self::W.exp_biguint(&(k * count as u64)); let mut res = [Self::BaseField::ZERO; D]; for (i, z) in z0.powers().take(D).enumerate() { res[i] = arr[i] * z; diff --git a/src/field/extension_field/quadratic.rs b/src/field/extension_field/quadratic.rs index 256803ab..5324ad2a 100644 --- a/src/field/extension_field/quadratic.rs +++ b/src/field/extension_field/quadratic.rs @@ -3,6 +3,8 @@ use std::hash::Hash; use std::iter::{Product, Sum}; use std::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssign}; +use itertools::Itertools; +use num::bigint::BigUint; use rand::Rng; use serde::{Deserialize, Serialize}; @@ -51,9 +53,7 @@ impl Field for QuadraticCrandallField { const TWO: Self = Self([CrandallField::TWO, CrandallField::ZERO]); const NEG_ONE: Self = Self([CrandallField::NEG_ONE, CrandallField::ZERO]); - const CHARACTERISTIC: u64 = CrandallField::ORDER; - // Does not fit in 64-bits. - const ORDER: u64 = 0; + const CHARACTERISTIC: u64 = CrandallField::CHARACTERISTIC; const TWO_ADICITY: usize = 29; const MULTIPLICATIVE_GROUP_GENERATOR: Self = Self([ CrandallField(6483724566312148654), @@ -65,6 +65,10 @@ impl Field for QuadraticCrandallField { const POWER_OF_TWO_GENERATOR: Self = Self([CrandallField::ZERO, CrandallField(14420468973723774561)]); + fn order() -> BigUint { + CrandallField::order() * CrandallField::order() + } + // Algorithm 11.3.4 in Handbook of Elliptic and Hyperelliptic Curve Cryptography. fn try_inverse(&self) -> Option { if self.is_zero() { @@ -86,6 +90,24 @@ impl Field for QuadraticCrandallField { >::BaseField::from_canonical_u64(n).into() } + fn to_canonical_biguint(&self) -> BigUint { + let first = self.0[0].to_canonical_biguint(); + let second = self.0[1].to_canonical_biguint(); + let combined = second * Self::CHARACTERISTIC + first; + + combined + } + + fn from_canonical_biguint(n: BigUint) -> Self { + let smaller = n.clone() % Self::CHARACTERISTIC; + let larger = n.clone() / Self::CHARACTERISTIC; + + Self([ + >::BaseField::from_canonical_biguint(smaller), + >::BaseField::from_canonical_biguint(larger), + ]) + } + fn rand_from_rng(rng: &mut R) -> Self { Self([ >::BaseField::rand_from_rng(rng), @@ -200,6 +222,7 @@ mod tests { use crate::field::extension_field::quadratic::QuadraticCrandallField; use crate::field::extension_field::{FieldExtension, Frobenius}; use crate::field::field::Field; + use crate::test_field_arithmetic; #[test] fn test_add_neg_sub_mul() { @@ -238,14 +261,14 @@ mod tests { type F = QuadraticCrandallField; let x = F::rand(); assert_eq!( - x.exp(>::BaseField::ORDER), + x.exp_biguint(&>::BaseField::order()), x.frobenius() ); } #[test] fn test_field_order() { - // F::ORDER = 340282366831806780677557380898690695169 = 18446744071293632512 *18446744071293632514 + 1 + // F::order() = 340282366831806780677557380898690695169 = 18446744071293632512 *18446744071293632514 + 1 type F = QuadraticCrandallField; let x = F::rand(); assert_eq!( @@ -257,7 +280,7 @@ mod tests { #[test] fn test_power_of_two_gen() { type F = QuadraticCrandallField; - // F::ORDER = 2^29 * 2762315674048163 * 229454332791453 + 1 + // F::order() = 2^29 * 2762315674048163 * 229454332791453 + 1 assert_eq!( F::MULTIPLICATIVE_GROUP_GENERATOR .exp(2762315674048163) @@ -270,4 +293,6 @@ mod tests { >::BaseField::POWER_OF_TWO_GENERATOR.into() ); } + + test_field_arithmetic!(crate::field::extension_field::quadratic::QuadraticCrandallField); } diff --git a/src/field/extension_field/quartic.rs b/src/field/extension_field/quartic.rs index 9390c521..f38f103a 100644 --- a/src/field/extension_field/quartic.rs +++ b/src/field/extension_field/quartic.rs @@ -3,6 +3,9 @@ use std::hash::Hash; use std::iter::{Product, Sum}; use std::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssign}; +use itertools::Itertools; +use num::bigint::BigUint; +use num::traits::Pow; use rand::Rng; use serde::{Deserialize, Serialize}; @@ -73,9 +76,8 @@ impl Field for QuarticCrandallField { CrandallField::ZERO, ]); - const CHARACTERISTIC: u64 = CrandallField::ORDER; + const CHARACTERISTIC: u64 = CrandallField::CHARACTERISTIC; // Does not fit in 64-bits. - const ORDER: u64 = 0; const TWO_ADICITY: usize = 30; const MULTIPLICATIVE_GROUP_GENERATOR: Self = Self([ CrandallField(12476589904174392631), @@ -93,6 +95,10 @@ impl Field for QuarticCrandallField { CrandallField(15170983443234254033), ]); + fn order() -> BigUint { + CrandallField::order().pow(4u32) + } + // Algorithm 11.3.4 in Handbook of Elliptic and Hyperelliptic Curve Cryptography. fn try_inverse(&self) -> Option { if self.is_zero() { @@ -117,6 +123,40 @@ impl Field for QuarticCrandallField { >::BaseField::from_canonical_u64(n).into() } + fn to_canonical_biguint(&self) -> BigUint { + let first = self.0[0].to_canonical_biguint(); + let second = self.0[1].to_canonical_biguint(); + let third = self.0[2].to_canonical_biguint(); + let fourth = self.0[3].to_canonical_biguint(); + + let mut combined = fourth; + combined *= Self::CHARACTERISTIC; + combined += third; + combined *= Self::CHARACTERISTIC; + combined += second; + combined *= Self::CHARACTERISTIC; + combined += first; + + combined + } + + fn from_canonical_biguint(n: BigUint) -> Self { + let first = &n % Self::CHARACTERISTIC; + let mut remaining = &n / Self::CHARACTERISTIC; + let second = &remaining % Self::CHARACTERISTIC; + remaining = remaining / Self::CHARACTERISTIC; + let third = &remaining % Self::CHARACTERISTIC; + remaining = remaining / Self::CHARACTERISTIC; + let fourth = &remaining % Self::CHARACTERISTIC; + + Self([ + >::BaseField::from_canonical_biguint(first), + >::BaseField::from_canonical_biguint(second), + >::BaseField::from_canonical_biguint(third), + >::BaseField::from_canonical_biguint(fourth), + ]) + } + fn rand_from_rng(rng: &mut R) -> Self { Self([ >::BaseField::rand_from_rng(rng), @@ -249,6 +289,7 @@ mod tests { use crate::field::extension_field::quartic::QuarticCrandallField; use crate::field::extension_field::{FieldExtension, Frobenius}; use crate::field::field::Field; + use crate::test_field_arithmetic; fn exp_naive(x: F, power: u128) -> F { let mut current = x; @@ -301,7 +342,7 @@ mod tests { const D: usize = 4; let x = F::rand(); assert_eq!( - exp_naive(x, >::BaseField::ORDER as u128), + x.exp_biguint(&>::BaseField::order()), x.frobenius() ); for count in 2..D { @@ -314,7 +355,7 @@ mod tests { #[test] fn test_field_order() { - // F::ORDER = 340282366831806780677557380898690695168 * 340282366831806780677557380898690695170 + 1 + // F::order() = 340282366831806780677557380898690695168 * 340282366831806780677557380898690695170 + 1 type F = QuarticCrandallField; let x = F::rand(); assert_eq!( @@ -329,7 +370,7 @@ mod tests { #[test] fn test_power_of_two_gen() { type F = QuarticCrandallField; - // F::ORDER = 2^30 * 1090552343587053358839971118999869 * 98885475095492590491252558464653635 + 1 + // F::order() = 2^30 * 1090552343587053358839971118999869 * 98885475095492590491252558464653635 + 1 assert_eq!( exp_naive( exp_naive( @@ -346,4 +387,6 @@ mod tests { >::BaseField::POWER_OF_TWO_GENERATOR.into() ); } + + test_field_arithmetic!(crate::field::extension_field::quartic::QuarticCrandallField); } diff --git a/src/field/extension_field/target.rs b/src/field/extension_field/target.rs index 455ee38f..6083e2c4 100644 --- a/src/field/extension_field/target.rs +++ b/src/field/extension_field/target.rs @@ -31,8 +31,8 @@ impl ExtensionTarget { return self.repeated_frobenius(count % D, builder); } let arr = self.to_target_array(); - let k = (F::ORDER - 1) / (D as u64); - let z0 = F::Extension::W.exp(k * count as u64); + let k = (F::order() - 1u32) / (D as u64); + let z0 = F::Extension::W.exp_biguint(&(k * count as u64)); let zs = z0 .powers() .take(D) diff --git a/src/field/field.rs b/src/field/field.rs index 28a52202..fd5f8ac1 100644 --- a/src/field/field.rs +++ b/src/field/field.rs @@ -4,7 +4,8 @@ use std::hash::Hash; use std::iter::{Product, Sum}; use std::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssign}; -use num::Integer; +use num::bigint::BigUint; +use num::{Integer, One, Zero}; use rand::Rng; use serde::de::DeserializeOwned; use serde::Serialize; @@ -44,7 +45,6 @@ pub trait Field: const NEG_ONE: Self; const CHARACTERISTIC: u64; - const ORDER: u64; const TWO_ADICITY: usize; /// Generator of the entire multiplicative group, i.e. all non-zero elements. @@ -52,6 +52,8 @@ pub trait Field: /// Generator of a multiplicative subgroup of order `2^TWO_ADICITY`. const POWER_OF_TWO_GENERATOR: Self; + fn order() -> BigUint; + fn is_zero(&self) -> bool { *self == Self::ZERO } @@ -183,6 +185,12 @@ pub trait Field: Self::from_canonical_u64(n as u64) } + fn to_canonical_biguint(&self) -> BigUint; + + fn from_canonical_biguint(n: BigUint) -> Self; + + fn rand_from_rng(rng: &mut R) -> Self; + fn bits(&self) -> usize { bits_u64(self.to_canonical_u64()) } @@ -212,18 +220,33 @@ pub trait Field: self.exp(power as u64) } + fn exp_biguint(&self, power: &BigUint) -> Self { + let digits = power.to_u32_digits(); + let radix = 1u64 << 32; + + let mut result = Self::ONE; + for (radix_power, &digit) in digits.iter().enumerate() { + let mut current = self.exp_u32(digit); + for _ in 0..radix_power { + current = current.exp(radix); + } + result *= current; + } + result + } + /// Returns whether `x^power` is a permutation of this field. fn is_monomial_permutation(power: u64) -> bool { match power { 0 => false, 1 => true, - _ => (Self::ORDER - 1).gcd(&power) == 1, + _ => (Self::order() - 1u32).gcd(&BigUint::from(power)).is_one(), } } fn kth_root(&self, k: u64) -> Self { - let p = Self::ORDER; - let p_minus_1 = p - 1; + let p = Self::order().clone(); + let p_minus_1 = &p - 1u32; debug_assert!( Self::is_monomial_permutation(k), "Not a permutation of this field" @@ -236,10 +259,10 @@ pub trait Field: // x^((p + n(p - 1))/k)^k = x, // implying that x^((p + n(p - 1))/k) is a k'th root of x. for n in 0..k { - let numerator = p as u128 + n as u128 * p_minus_1 as u128; - if numerator % k as u128 == 0 { - let power = (numerator / k as u128) as u64 % p_minus_1; - return self.exp(power); + let numerator = &p + &p_minus_1 * n; + if (&numerator % k).is_zero() { + let power = (numerator / k) % p_minus_1; + return self.exp_biguint(&power); } } panic!( @@ -292,10 +315,6 @@ pub trait Field: Self::mds(vec.to_vec()).try_into().unwrap() } - fn rand_from_rng(rng: &mut R) -> Self { - Self::from_canonical_u64(rng.gen_range(0, Self::ORDER)) - } - fn rand() -> Self { Self::rand_from_rng(&mut rand::thread_rng()) } diff --git a/src/field/field_testing.rs b/src/field/field_testing.rs index 1f5bff6f..ffab1d9d 100644 --- a/src/field/field_testing.rs +++ b/src/field/field_testing.rs @@ -1,18 +1,21 @@ +use num::{bigint::BigUint, Zero}; + use crate::field::field::Field; -use crate::util::{bits_u64, ceil_div_usize}; +use crate::util::ceil_div_usize; /// Generates a series of non-negative integers less than /// `modulus` which cover a range of values and which will /// generate lots of carries, especially at `word_bits` word /// boundaries. -pub fn test_inputs(modulus: u64, word_bits: usize) -> Vec { - assert!(word_bits == 32 || word_bits == 64); - let modwords = ceil_div_usize(bits_u64(modulus), word_bits); +pub fn test_inputs(modulus: BigUint, word_bits: usize) -> Vec { + //assert!(word_bits == 32 || word_bits == 64); + let modwords = ceil_div_usize(modulus.bits() as usize, word_bits); // Start with basic set close to zero: 0 .. 10 const BIGGEST_SMALL: u32 = 10; - let smalls: Vec<_> = (0..BIGGEST_SMALL).map(u64::from).collect(); + let smalls: Vec<_> = (0..BIGGEST_SMALL).map(BigUint::from).collect(); // ... and close to MAX: MAX - x - let word_max = (1u64 << word_bits) - 1; + let word_max = (BigUint::from(1u32) << word_bits) - 1u32; + let multiple_words_max = (BigUint::from(1u32) << modwords * word_bits) - 1u32; let bigs = smalls.iter().map(|x| &word_max - x).collect(); let one_words = [smalls, bigs].concat(); // For each of the one word inputs above, create a new one at word i. @@ -22,28 +25,28 @@ pub fn test_inputs(modulus: u64, word_bits: usize) -> Vec { one_words .iter() .map(|x| x << (word_bits * i)) - .collect::>() + .collect::>() }) .collect(); - let basic_inputs: Vec = [one_words, multiple_words].concat(); + let basic_inputs: Vec = [one_words, multiple_words].concat(); // Biggest value that will fit in `modwords` words // Inputs 'difference from' maximum value let diff_max = basic_inputs .iter() - .map(|&x| u64::MAX - x) - .filter(|&x| x < modulus) + .map(|x| &multiple_words_max - x) + .filter(|x| x < &modulus) .collect(); // Inputs 'difference from' modulus value let diff_mod = basic_inputs .iter() - .filter(|&&x| x < modulus && x != 0) - .map(|&x| modulus - x) + .filter(|&x| x < &modulus && !x.is_zero()) + .map(|x| &modulus - x) .collect(); let basics = basic_inputs .into_iter() - .filter(|&x| x < modulus) - .collect::>(); + .filter(|x| x < &modulus) + .collect::>(); [basics, diff_max, diff_mod].concat() // // There should be a nicer way to express the code above; something @@ -59,20 +62,21 @@ pub fn test_inputs(modulus: u64, word_bits: usize) -> Vec { /// coordinate-wise to the inputs from `test_inputs(modulus, /// word_bits)` and panic if the two resulting vectors differ. pub fn run_unaryop_test_cases( - modulus: u64, + modulus: BigUint, word_bits: usize, op: UnaryOp, expected_op: ExpectedOp, ) where F: Field, UnaryOp: Fn(F) -> F, - ExpectedOp: Fn(u64) -> u64, + ExpectedOp: Fn(BigUint) -> BigUint, { let inputs = test_inputs(modulus, word_bits); - let expected: Vec<_> = inputs.iter().map(|&x| expected_op(x)).collect(); + let expected: Vec<_> = inputs.iter().map(|x| expected_op(x.clone())).collect(); let output: Vec<_> = inputs .iter() - .map(|&x| op(F::from_canonical_u64(x)).to_canonical_u64()) + .cloned() + .map(|x| op(F::from_canonical_biguint(x)).to_canonical_biguint()) .collect(); // Compare expected outputs with actual outputs for i in 0..inputs.len() { @@ -90,14 +94,14 @@ pub fn run_unaryop_test_cases( /// `inputs.len()`. Panic if the two functions ever give /// different answers. pub fn run_binaryop_test_cases( - modulus: u64, + modulus: BigUint, word_bits: usize, op: BinaryOp, expected_op: ExpectedOp, ) where F: Field, BinaryOp: Fn(F, F) -> F, - ExpectedOp: Fn(u64, u64) -> u64, + ExpectedOp: Fn(BigUint, BigUint) -> BigUint, { let inputs = test_inputs(modulus, word_bits); @@ -122,8 +126,12 @@ pub fn run_binaryop_test_cases( let output: Vec<_> = inputs .iter() .zip(shifted_inputs.clone()) - .map(|(&x, &y)| { - op(F::from_canonical_u64(x), F::from_canonical_u64(y)).to_canonical_u64() + .map(|(x, y)| { + op( + F::from_canonical_biguint(x.clone()), + F::from_canonical_biguint(y.clone()), + ) + .to_canonical_biguint() }) .collect(); @@ -139,108 +147,14 @@ pub fn run_binaryop_test_cases( } #[macro_export] -macro_rules! test_arithmetic { +macro_rules! test_field_arithmetic { ($field:ty) => { - mod arithmetic { - use std::ops::{Add, Mul, Neg, Sub}; + mod field_arithmetic { + use num::{bigint::BigUint, One, Zero}; + use rand::{thread_rng, Rng}; use crate::field::field::Field; - // Can be 32 or 64; doesn't have to be computer's actual word - // bits. Choosing 32 gives more tests... - const WORD_BITS: usize = 32; - - #[test] - fn arithmetic_addition() { - let modulus = <$field>::ORDER; - crate::field::field_testing::run_binaryop_test_cases( - modulus, - WORD_BITS, - <$field>::add, - |x, y| { - let (z, over) = x.overflowing_add(y); - if over { - z.overflowing_sub(modulus).0 - } else if z >= modulus { - z - modulus - } else { - z - } - }, - ) - } - - #[test] - fn arithmetic_subtraction() { - let modulus = <$field>::ORDER; - crate::field::field_testing::run_binaryop_test_cases( - modulus, - WORD_BITS, - <$field>::sub, - |x, y| { - if x >= y { - x - y - } else { - &modulus - y + x - } - }, - ) - } - - #[test] - fn arithmetic_negation() { - let modulus = <$field>::ORDER; - crate::field::field_testing::run_unaryop_test_cases( - modulus, - WORD_BITS, - <$field>::neg, - |x| { - if x == 0 { - 0 - } else { - modulus - x - } - }, - ) - } - - #[test] - fn arithmetic_multiplication() { - let modulus = <$field>::ORDER; - crate::field::field_testing::run_binaryop_test_cases( - modulus, - WORD_BITS, - <$field>::mul, - |x, y| ((x as u128) * (y as u128) % (modulus as u128)) as u64, - ) - } - - #[test] - fn arithmetic_square() { - let modulus = <$field>::ORDER; - crate::field::field_testing::run_unaryop_test_cases( - modulus, - WORD_BITS, - |x: $field| x.square(), - |x| ((x as u128) * (x as u128) % (modulus as u128)) as u64, - ) - } - - #[test] - fn inversion() { - let zero = <$field>::ZERO; - let one = <$field>::ONE; - let order = <$field>::ORDER; - - assert_eq!(zero.try_inverse(), None); - - for &x in &[1, 2, 3, order - 3, order - 2, order - 1] { - let x = <$field>::from_canonical_u64(x); - let inv = x.inverse(); - assert_eq!(x * inv, one); - } - } - #[test] fn batch_inversion() { let xs = (1..=3) @@ -264,10 +178,16 @@ macro_rules! test_arithmetic { #[test] fn negation() { let zero = <$field>::ZERO; - let order = <$field>::ORDER; + let order = <$field>::order(); - for &i in &[0, 1, 2, order - 2, order - 1] { - let i_f = <$field>::from_canonical_u64(i); + for i in [ + BigUint::zero(), + BigUint::one(), + BigUint::from(2u32), + &order - 1u32, + &order - 2u32, + ] { + let i_f = <$field>::from_canonical_biguint(i); assert_eq!(i_f + -i_f, zero); } } @@ -307,13 +227,20 @@ macro_rules! test_arithmetic { } #[test] - fn subtraction() { + fn exponentiation_large() { type F = $field; - let (a, b) = (F::from_canonical_u64((F::ORDER + 1) / 2), F::TWO); - let x = a * b; - assert_eq!(x, F::ONE); - assert_eq!(F::ZERO - x, F::NEG_ONE); + let mut rng = rand::thread_rng(); + + let base = F::rand(); + let pow = BigUint::from(rng.gen::()); + let cycles = rng.gen::(); + let mul_group_order = F::order() - 1u32; + let big_pow = &pow + &mul_group_order * cycles; + let big_pow_wrong = &pow + &mul_group_order * cycles + 1u32; + + assert_eq!(base.exp_biguint(&pow), base.exp_biguint(&big_pow)); + assert_ne!(base.exp_biguint(&pow), base.exp_biguint(&big_pow_wrong)); } #[test] @@ -332,3 +259,122 @@ macro_rules! test_arithmetic { } }; } + +#[macro_export] +macro_rules! test_prime_field_arithmetic { + ($field:ty) => { + mod prime_field_arithmetic { + use std::ops::{Add, Mul, Neg, Sub}; + + use num::{bigint::BigUint, One, Zero}; + + use crate::field::field::Field; + + // Can be 32 or 64; doesn't have to be computer's actual word + // bits. Choosing 32 gives more tests... + const WORD_BITS: usize = 32; + + #[test] + fn arithmetic_addition() { + let modulus = <$field>::order(); + crate::field::field_testing::run_binaryop_test_cases( + modulus.clone(), + WORD_BITS, + <$field>::add, + |x, y| (&x + &y) % &modulus, + ) + } + + #[test] + fn arithmetic_subtraction() { + let modulus = <$field>::order(); + crate::field::field_testing::run_binaryop_test_cases( + modulus.clone(), + WORD_BITS, + <$field>::sub, + |x, y| { + if x >= y { + &x - &y + } else { + &modulus - &y + &x + } + }, + ) + } + + #[test] + fn arithmetic_negation() { + let modulus = <$field>::order(); + crate::field::field_testing::run_unaryop_test_cases( + modulus.clone(), + WORD_BITS, + <$field>::neg, + |x| { + if x.is_zero() { + BigUint::zero() + } else { + &modulus - &x + } + }, + ) + } + + #[test] + fn arithmetic_multiplication() { + let modulus = <$field>::order(); + crate::field::field_testing::run_binaryop_test_cases( + modulus.clone(), + WORD_BITS, + <$field>::mul, + |x, y| &x * &y % &modulus, + ) + } + + #[test] + fn arithmetic_square() { + let modulus = <$field>::order(); + crate::field::field_testing::run_unaryop_test_cases( + modulus.clone(), + WORD_BITS, + |x: $field| x.square(), + |x| (&x * &x) % &modulus, + ) + } + + #[test] + fn inversion() { + let zero = <$field>::ZERO; + let one = <$field>::ONE; + let order = <$field>::order(); + + assert_eq!(zero.try_inverse(), None); + + for x in [ + BigUint::one(), + BigUint::from(2u32), + BigUint::from(3u32), + &order - 3u32, + &order - 2u32, + &order - 1u32, + ] { + let x = <$field>::from_canonical_biguint(x); + let inv = x.inverse(); + assert_eq!(x * inv, one); + } + } + + #[test] + fn subtraction() { + type F = $field; + + let (a, b) = ( + F::from_canonical_biguint((F::order() + 1u32) / 2u32), + F::TWO, + ); + let x = a * b; + assert_eq!(x, F::ONE); + assert_eq!(F::ZERO - x, F::NEG_ONE); + } + } + }; +} diff --git a/src/fri/prover.rs b/src/fri/prover.rs index 5a8f09e5..a4147da7 100644 --- a/src/fri/prover.rs +++ b/src/fri/prover.rs @@ -111,7 +111,7 @@ fn fri_proof_of_work(current_hash: Hash, config: &FriConfig) -> F { ) .to_canonical_u64() .leading_zeros() - >= config.proof_of_work_bits + F::ORDER.leading_zeros() + >= config.proof_of_work_bits + (64 - F::order().bits()) as u32 }) .map(F::from_canonical_u64) .expect("Proof of work failed. This is highly unlikely!") diff --git a/src/fri/recursive_verifier.rs b/src/fri/recursive_verifier.rs index fec8c065..9b3cdd6e 100644 --- a/src/fri/recursive_verifier.rs +++ b/src/fri/recursive_verifier.rs @@ -61,7 +61,10 @@ impl, const D: usize> CircuitBuilder { inputs.push(proof.pow_witness); let hash = self.hash_n_to_m(inputs, 1, false)[0]; - self.assert_leading_zeros(hash, config.proof_of_work_bits + F::ORDER.leading_zeros()); + self.assert_leading_zeros( + hash, + config.proof_of_work_bits + (64 - F::order().bits()) as u32, + ); } pub fn verify_fri_proof( diff --git a/src/fri/verifier.rs b/src/fri/verifier.rs index 6803c00d..27e775e6 100644 --- a/src/fri/verifier.rs +++ b/src/fri/verifier.rs @@ -59,7 +59,7 @@ fn fri_verify_proof_of_work, const D: usize>( ); ensure!( hash.to_canonical_u64().leading_zeros() - >= config.proof_of_work_bits + F::ORDER.leading_zeros(), + >= config.proof_of_work_bits + (64 - F::order().bits()) as u32, "Invalid proof of work witness." ); diff --git a/src/gadgets/arithmetic.rs b/src/gadgets/arithmetic.rs index b858671d..1e5e6001 100644 --- a/src/gadgets/arithmetic.rs +++ b/src/gadgets/arithmetic.rs @@ -172,14 +172,12 @@ impl, const D: usize> CircuitBuilder { /// Exponentiate `base` to the power of `exponent`, given by its little-endian bits. pub fn exp_from_bits(&mut self, base: Target, exponent_bits: &[Target]) -> Target { let mut current = base; - let one_ext = self.one_extension(); - let mut product = self.one(); + let one = self.one(); + let mut product = one; for &bit in exponent_bits { - // TODO: Add base field select. - let current_ext = self.convert_to_ext(current); - let multiplicand = self.select(bit, current_ext, one_ext); - product = self.mul(product, multiplicand.0[0]); + let multiplicand = self.select(bit, current, one); + product = self.mul(product, multiplicand); current = self.mul(current, current); } @@ -195,14 +193,12 @@ impl, const D: usize> CircuitBuilder { exponent_bits: impl Iterator>, ) -> Target { let mut current = base; - let one_ext = self.one_extension(); - let mut product = self.one(); + let one = self.one(); + let mut product = one; - for bit in exponent_bits { - let current_ext = self.convert_to_ext(current); - // TODO: Add base field select. - let multiplicand = self.select(*bit.borrow(), one_ext, current_ext); - product = self.mul(product, multiplicand.0[0]); + for &bit in exponent_bits { + let multiplicand = self.select(*bit.borrow(), one, current); + product = self.mul(product, multiplicand); current = self.mul(current, current); } diff --git a/src/gadgets/mod.rs b/src/gadgets/mod.rs index 2f216870..4c4160e1 100644 --- a/src/gadgets/mod.rs +++ b/src/gadgets/mod.rs @@ -5,6 +5,6 @@ pub mod insert; pub mod interpolation; pub mod polynomial; pub mod range_check; -pub mod rotate; +pub mod select; pub mod split_base; pub(crate) mod split_join; diff --git a/src/gadgets/rotate.rs b/src/gadgets/rotate.rs deleted file mode 100644 index 67677795..00000000 --- a/src/gadgets/rotate.rs +++ /dev/null @@ -1,167 +0,0 @@ -use crate::circuit_builder::CircuitBuilder; -use crate::field::extension_field::target::ExtensionTarget; -use crate::field::extension_field::Extendable; -use crate::target::Target; -use crate::util::log2_ceil; - -impl, const D: usize> CircuitBuilder { - /// Selects `x` or `y` based on `b`, which is assumed to be binary. - /// In particular, this returns `if b { x } else { y }`. - /// Note: This does not range-check `b`. - // TODO: This uses 10 gates per call. If addends are added to `MulExtensionGate`, this will be - // reduced to 2 gates. We could also use a new degree 2 `SelectGate` for this. - // If `num_routed_wire` is larger than 26, we could batch two `select` in one gate. - pub fn select( - &mut self, - b: Target, - x: ExtensionTarget, - y: ExtensionTarget, - ) -> ExtensionTarget { - let b_y_minus_y = self.scalar_mul_sub_extension(b, y, y); - self.scalar_mul_sub_extension(b, x, b_y_minus_y) - } - - /// Left-rotates an array `k` times if `b=1` else return the same array. - pub fn rotate_left_fixed( - &mut self, - b: Target, - k: usize, - v: &[ExtensionTarget], - ) -> Vec> { - let len = v.len(); - debug_assert!(k < len, "Trying to rotate by more than the vector length."); - let mut res = Vec::new(); - - for i in 0..len { - res.push(self.select(b, v[(i + k) % len], v[i])); - } - - res - } - - /// Left-rotates an array `k` times if `b=1` else return the same array. - pub fn rotate_right_fixed( - &mut self, - b: Target, - k: usize, - v: &[ExtensionTarget], - ) -> Vec> { - let len = v.len(); - debug_assert!(k < len, "Trying to rotate by more than the vector length."); - let mut res = Vec::new(); - - for i in 0..len { - res.push(self.select(b, v[(len + i - k) % len], v[i])); - } - - res - } - - /// Left-rotates an vector by the `Target` having bits given in little-endian by `num_rotation_bits`. - pub fn rotate_left_from_bits( - &mut self, - num_rotation_bits: &[Target], - v: &[ExtensionTarget], - ) -> Vec> { - let mut v = v.to_vec(); - - for i in 0..num_rotation_bits.len() { - v = self.rotate_left_fixed(num_rotation_bits[i], 1 << i, &v); - } - - v - } - - pub fn rotate_right_from_bits( - &mut self, - num_rotation_bits: &[Target], - v: &[ExtensionTarget], - ) -> Vec> { - let mut v = v.to_vec(); - - for i in 0..num_rotation_bits.len() { - v = self.rotate_right_fixed(num_rotation_bits[i], 1 << i, &v); - } - - v - } - - /// Left-rotates an array by `num_rotation`. Assumes that `num_rotation` is range-checked to be - /// less than `2^len_bits`. - pub fn rotate_left( - &mut self, - num_rotation: Target, - v: &[ExtensionTarget], - ) -> Vec> { - let len_bits = log2_ceil(v.len()); - let bits = self.split_le(num_rotation, len_bits); - - self.rotate_left_from_bits(&bits, v) - } - - pub fn rotate_right( - &mut self, - num_rotation: Target, - v: &[ExtensionTarget], - ) -> Vec> { - let len_bits = log2_ceil(v.len()); - let bits = self.split_le(num_rotation, len_bits); - - self.rotate_right_from_bits(&bits, v) - } -} - -#[cfg(test)] -mod tests { - use anyhow::Result; - - use super::*; - use crate::circuit_data::CircuitConfig; - use crate::field::crandall_field::CrandallField; - use crate::field::extension_field::quartic::QuarticCrandallField; - use crate::field::field::Field; - use crate::verifier::verify; - use crate::witness::PartialWitness; - - fn real_rotate( - num_rotation: usize, - v: &[ExtensionTarget], - ) -> Vec> { - let mut res = v.to_vec(); - res.rotate_left(num_rotation); - res - } - - fn test_rotate_given_len(len: usize) -> Result<()> { - type F = CrandallField; - type FF = QuarticCrandallField; - let config = CircuitConfig::large_config(); - let mut builder = CircuitBuilder::::new(config); - let v = (0..len) - .map(|_| builder.constant_extension(FF::rand())) - .collect::>(); - - for i in 0..len { - let it = builder.constant(F::from_canonical_usize(i)); - let rotated = real_rotate(i, &v); - let purported_rotated = builder.rotate_left(it, &v); - - for (x, y) in rotated.into_iter().zip(purported_rotated) { - builder.assert_equal_extension(x, y); - } - } - - let data = builder.build(); - let proof = data.prove(PartialWitness::new())?; - - verify(proof, &data.verifier_only, &data.common) - } - - #[test] - fn test_rotate() -> Result<()> { - for len in 1..5 { - test_rotate_given_len(len)?; - } - Ok(()) - } -} diff --git a/src/gadgets/select.rs b/src/gadgets/select.rs new file mode 100644 index 00000000..bbd36d76 --- /dev/null +++ b/src/gadgets/select.rs @@ -0,0 +1,76 @@ +use crate::circuit_builder::CircuitBuilder; +use crate::field::extension_field::target::ExtensionTarget; +use crate::field::extension_field::Extendable; +use crate::gates::arithmetic::ArithmeticExtensionGate; +use crate::target::Target; + +impl, const D: usize> CircuitBuilder { + /// Selects `x` or `y` based on `b`, which is assumed to be binary, i.e., this returns `if b { x } else { y }`. + /// This expression is gotten as `bx - (by-y)`, which can be computed with a single `ArithmeticExtensionGate`. + /// Note: This does not range-check `b`. + pub fn select_ext( + &mut self, + b: Target, + x: ExtensionTarget, + y: ExtensionTarget, + ) -> ExtensionTarget { + let b_ext = self.convert_to_ext(b); + let gate = self.num_gates(); + // Holds `by - y`. + let first_out = + ExtensionTarget::from_range(gate, ArithmeticExtensionGate::::wires_first_output()); + self.double_arithmetic_extension(F::ONE, F::NEG_ONE, b_ext, y, y, b_ext, x, first_out) + .1 + } + + /// See `select_ext`. + pub fn select(&mut self, b: Target, x: Target, y: Target) -> Target { + let x_ext = self.convert_to_ext(x); + let y_ext = self.convert_to_ext(y); + self.select_ext(b, x_ext, y_ext).to_target_array()[0] + } +} + +#[cfg(test)] +mod tests { + use anyhow::Result; + + use super::*; + use crate::circuit_data::CircuitConfig; + use crate::field::crandall_field::CrandallField; + use crate::field::extension_field::quartic::QuarticCrandallField; + use crate::field::field::Field; + use crate::verifier::verify; + use crate::witness::PartialWitness; + + #[test] + fn test_select() -> Result<()> { + type F = CrandallField; + type FF = QuarticCrandallField; + let config = CircuitConfig::large_config(); + let mut builder = CircuitBuilder::::new(config); + let mut pw = PartialWitness::new(); + + let (x, y) = (FF::rand(), FF::rand()); + let xt = builder.add_virtual_extension_target(); + let yt = builder.add_virtual_extension_target(); + let truet = builder.add_virtual_target(); + let falset = builder.add_virtual_target(); + + pw.set_extension_target(xt, x); + pw.set_extension_target(yt, y); + pw.set_target(truet, F::ONE); + pw.set_target(falset, F::ZERO); + + let should_be_x = builder.select_ext(truet, xt, yt); + let should_be_y = builder.select_ext(falset, xt, yt); + + builder.assert_equal_extension(should_be_x, xt); + builder.assert_equal_extension(should_be_y, yt); + + let data = builder.build(); + let proof = data.prove(pw)?; + + verify(proof, &data.verifier_only, &data.common) + } +} diff --git a/src/gates/arithmetic.rs b/src/gates/arithmetic.rs index cf39e09b..a3739ee5 100644 --- a/src/gates/arithmetic.rs +++ b/src/gates/arithmetic.rs @@ -2,11 +2,11 @@ use std::ops::Range; use crate::circuit_builder::CircuitBuilder; use crate::field::extension_field::target::ExtensionTarget; -use crate::field::extension_field::Extendable; +use crate::field::extension_field::{Extendable, FieldExtension}; use crate::gates::gate::{Gate, GateRef}; use crate::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator}; use crate::target::Target; -use crate::vars::{EvaluationTargets, EvaluationVars}; +use crate::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase}; use crate::witness::PartialWitness; /// A gate which can a linear combination `c0*x*y+c1*z` twice with the same `x`. @@ -74,6 +74,31 @@ impl, const D: usize> Gate for ArithmeticExtensionGate constraints } + fn eval_unfiltered_base(&self, vars: EvaluationVarsBase) -> Vec { + let const_0 = vars.local_constants[0]; + let const_1 = vars.local_constants[1]; + + let first_multiplicand_0 = vars.get_local_ext(Self::wires_first_multiplicand_0()); + let first_multiplicand_1 = vars.get_local_ext(Self::wires_first_multiplicand_1()); + let first_addend = vars.get_local_ext(Self::wires_first_addend()); + let second_multiplicand_0 = vars.get_local_ext(Self::wires_second_multiplicand_0()); + let second_multiplicand_1 = vars.get_local_ext(Self::wires_second_multiplicand_1()); + let second_addend = vars.get_local_ext(Self::wires_second_addend()); + let first_output = vars.get_local_ext(Self::wires_first_output()); + let second_output = vars.get_local_ext(Self::wires_second_output()); + + let first_computed_output = first_multiplicand_0 * first_multiplicand_1 * const_0.into() + + first_addend * const_1.into(); + let second_computed_output = second_multiplicand_0 * second_multiplicand_1 * const_0.into() + + second_addend * const_1.into(); + + let mut constraints = (first_output - first_computed_output) + .to_basefield_array() + .to_vec(); + constraints.extend((second_output - second_computed_output).to_basefield_array()); + constraints + } + fn eval_unfiltered_recursively( &self, builder: &mut CircuitBuilder, diff --git a/src/gates/base_sum.rs b/src/gates/base_sum.rs index 8ad189ee..b6645959 100644 --- a/src/gates/base_sum.rs +++ b/src/gates/base_sum.rs @@ -8,7 +8,7 @@ use crate::gates::gate::{Gate, GateRef}; use crate::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator}; use crate::plonk_common::{reduce_with_powers, reduce_with_powers_recursive}; use crate::target::Target; -use crate::vars::{EvaluationTargets, EvaluationVars}; +use crate::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase}; use crate::witness::PartialWitness; /// A gate which can decompose a number into base B little-endian limbs, @@ -57,6 +57,20 @@ impl, const D: usize, const B: usize> Gate for BaseSumGat constraints } + fn eval_unfiltered_base(&self, vars: EvaluationVarsBase) -> Vec { + let sum = vars.local_wires[Self::WIRE_SUM]; + let reversed_sum = vars.local_wires[Self::WIRE_REVERSED_SUM]; + let mut limbs = vars.local_wires[self.limbs()].to_vec(); + let computed_sum = reduce_with_powers(&limbs, F::from_canonical_usize(B)); + limbs.reverse(); + let computed_reversed_sum = reduce_with_powers(&limbs, F::from_canonical_usize(B)); + let mut constraints = vec![computed_sum - sum, computed_reversed_sum - reversed_sum]; + for limb in limbs { + constraints.push((0..B).map(|i| limb - F::from_canonical_usize(i)).product()); + } + constraints + } + fn eval_unfiltered_recursively( &self, builder: &mut CircuitBuilder, diff --git a/src/gates/constant.rs b/src/gates/constant.rs index 4049d058..4a5c4373 100644 --- a/src/gates/constant.rs +++ b/src/gates/constant.rs @@ -5,7 +5,7 @@ use crate::field::field::Field; use crate::gates::gate::{Gate, GateRef}; use crate::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator}; use crate::target::Target; -use crate::vars::{EvaluationTargets, EvaluationVars}; +use crate::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase}; use crate::wire::Wire; use crate::witness::PartialWitness; @@ -33,6 +33,12 @@ impl, const D: usize> Gate for ConstantGate { vec![output - input] } + fn eval_unfiltered_base(&self, vars: EvaluationVarsBase) -> Vec { + let input = vars.local_constants[Self::CONST_INPUT]; + let output = vars.local_wires[Self::WIRE_OUTPUT]; + vec![output - input] + } + fn eval_unfiltered_recursively( &self, builder: &mut CircuitBuilder, diff --git a/src/gates/gmimc.rs b/src/gates/gmimc.rs index 04c9d54a..b4fa6345 100644 --- a/src/gates/gmimc.rs +++ b/src/gates/gmimc.rs @@ -8,7 +8,7 @@ use crate::gates::gate::{Gate, GateRef}; use crate::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator}; use crate::gmimc::gmimc_automatic_constants; use crate::target::Target; -use crate::vars::{EvaluationTargets, EvaluationVars}; +use crate::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase}; use crate::wire::Wire; use crate::witness::PartialWitness; @@ -112,6 +112,55 @@ impl, const D: usize, const R: usize> Gate for GMiMCGate< constraints } + fn eval_unfiltered_base(&self, vars: EvaluationVarsBase) -> Vec { + let mut constraints = Vec::with_capacity(self.num_constraints()); + + // Assert that `swap` is binary. + let swap = vars.local_wires[Self::WIRE_SWAP]; + constraints.push(swap * (swap - F::ONE)); + + let old_index_acc = vars.local_wires[Self::WIRE_INDEX_ACCUMULATOR_OLD]; + let new_index_acc = vars.local_wires[Self::WIRE_INDEX_ACCUMULATOR_NEW]; + let computed_new_index_acc = F::TWO * old_index_acc + swap; + constraints.push(computed_new_index_acc - new_index_acc); + + let mut state = Vec::with_capacity(12); + for i in 0..4 { + let a = vars.local_wires[i]; + let b = vars.local_wires[i + 4]; + state.push(a + swap * (b - a)); + } + for i in 0..4 { + let a = vars.local_wires[i + 4]; + let b = vars.local_wires[i]; + state.push(a + swap * (b - a)); + } + for i in 8..12 { + state.push(vars.local_wires[i]); + } + + // Value that is implicitly added to each element. + // See https://affine.group/2020/02/starkware-challenge + let mut addition_buffer = F::ZERO; + + for r in 0..R { + let active = r % W; + let cubing_input = state[active] + addition_buffer + self.constants[r].into(); + let cubing_input_wire = vars.local_wires[Self::wire_cubing_input(r)]; + constraints.push(cubing_input - cubing_input_wire); + let f = cubing_input_wire.cube(); + addition_buffer += f; + state[active] -= f; + } + + for i in 0..W { + state[i] += addition_buffer; + constraints.push(state[i] - vars.local_wires[Self::wire_output(i)]); + } + + constraints + } + fn eval_unfiltered_recursively( &self, builder: &mut CircuitBuilder, diff --git a/src/gates/insertion.rs b/src/gates/insertion.rs index 1bc0b454..4bfe97a9 100644 --- a/src/gates/insertion.rs +++ b/src/gates/insertion.rs @@ -9,7 +9,7 @@ use crate::field::field::Field; use crate::gates::gate::{Gate, GateRef}; use crate::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator}; use crate::target::Target; -use crate::vars::{EvaluationTargets, EvaluationVars}; +use crate::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase}; use crate::wire::Wire; use crate::witness::PartialWitness; @@ -114,6 +114,44 @@ impl, const D: usize> Gate for InsertionGate { constraints } + fn eval_unfiltered_base(&self, vars: EvaluationVarsBase) -> Vec { + let insertion_index = vars.local_wires[self.wires_insertion_index()]; + let list_items = (0..self.vec_size) + .map(|i| vars.get_local_ext(self.wires_original_list_item(i))) + .collect::>(); + let output_list_items = (0..=self.vec_size) + .map(|i| vars.get_local_ext(self.wires_output_list_item(i))) + .collect::>(); + let element_to_insert = vars.get_local_ext(self.wires_element_to_insert()); + + let mut constraints = Vec::new(); + let mut already_inserted = F::ZERO; + for r in 0..=self.vec_size { + let cur_index = F::from_canonical_usize(r); + let difference = cur_index - insertion_index; + let equality_dummy = vars.local_wires[self.wires_equality_dummy_for_round_r(r)]; + let insert_here = vars.local_wires[self.wires_insert_here_for_round_r(r)]; + + // The two equality constraints. + constraints.push(difference * equality_dummy - (F::ONE - insert_here)); + constraints.push(insert_here * difference); + + let mut new_item = element_to_insert * insert_here.into(); + if r > 0 { + new_item += list_items[r - 1] * already_inserted.into(); + } + already_inserted += insert_here; + if r < self.vec_size { + new_item += list_items[r] * (F::ONE - already_inserted).into(); + } + + // Output constraint. + constraints.extend((new_item - output_list_items[r]).to_basefield_array()); + } + + constraints + } + fn eval_unfiltered_recursively( &self, builder: &mut CircuitBuilder, diff --git a/src/gates/interpolation.rs b/src/gates/interpolation.rs index 17d34e3a..6d6594b4 100644 --- a/src/gates/interpolation.rs +++ b/src/gates/interpolation.rs @@ -10,8 +10,9 @@ use crate::field::interpolation::interpolant; use crate::gadgets::polynomial::PolynomialCoeffsExtAlgebraTarget; use crate::gates::gate::{Gate, GateRef}; use crate::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator}; +use crate::polynomial::polynomial::PolynomialCoeffs; use crate::target::Target; -use crate::vars::{EvaluationTargets, EvaluationVars}; +use crate::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase}; use crate::wire::Wire; use crate::witness::PartialWitness; @@ -121,6 +122,29 @@ impl, const D: usize> Gate for InterpolationGate { constraints } + fn eval_unfiltered_base(&self, vars: EvaluationVarsBase) -> Vec { + let mut constraints = Vec::with_capacity(self.num_constraints()); + + let coeffs = (0..self.num_points) + .map(|i| vars.get_local_ext(self.wires_coeff(i))) + .collect(); + let interpolant = PolynomialCoeffs::new(coeffs); + + for i in 0..self.num_points { + let point = vars.local_wires[self.wire_point(i)]; + let value = vars.get_local_ext(self.wires_value(i)); + let computed_value = interpolant.eval(point.into()); + constraints.extend(&(value - computed_value).to_basefield_array()); + } + + let evaluation_point = vars.get_local_ext(self.wires_evaluation_point()); + let evaluation_value = vars.get_local_ext(self.wires_evaluation_value()); + let computed_evaluation_value = interpolant.eval(evaluation_point); + constraints.extend(&(evaluation_value - computed_evaluation_value).to_basefield_array()); + + constraints + } + fn eval_unfiltered_recursively( &self, builder: &mut CircuitBuilder, @@ -313,31 +337,15 @@ mod tests { points: Vec, eval_point: FF, ) -> Vec { - let mut v = vec![F::ZERO; num_points * 5 + (coeffs.len() + 3) * D]; + let mut v = Vec::new(); + v.extend_from_slice(&points); for j in 0..num_points { - v[j] = points[j]; - } - for j in 0..num_points { - for i in 0..D { - v[num_points + D * j + i] = >::to_basefield_array( - &coeffs.eval(points[j].into()), - )[i]; - } - } - for i in 0..D { - v[num_points * 5 + i] = - >::to_basefield_array(&eval_point)[i]; - } - for i in 0..D { - v[num_points * 5 + D + i] = - >::to_basefield_array(&coeffs.eval(eval_point))[i]; + v.extend(coeffs.eval(points[j].into()).0); } + v.extend(eval_point.0); + v.extend(coeffs.eval(eval_point).0); for i in 0..coeffs.len() { - for (j, input) in - (0..D).zip(num_points * 5 + (2 + i) * D..num_points * 5 + (3 + i) * D) - { - v[input] = >::to_basefield_array(&coeffs.coeffs[i])[j]; - } + v.extend(coeffs.coeffs[i].0); } v.iter().map(|&x| x.into()).collect::>() } diff --git a/src/gates/noop.rs b/src/gates/noop.rs index a12df932..c27b22bf 100644 --- a/src/gates/noop.rs +++ b/src/gates/noop.rs @@ -3,7 +3,7 @@ use crate::field::extension_field::target::ExtensionTarget; use crate::field::extension_field::Extendable; use crate::gates::gate::{Gate, GateRef}; use crate::generator::WitnessGenerator; -use crate::vars::{EvaluationTargets, EvaluationVars}; +use crate::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase}; /// A gate which does nothing. pub struct NoopGate; @@ -23,6 +23,10 @@ impl, const D: usize> Gate for NoopGate { Vec::new() } + fn eval_unfiltered_base(&self, _vars: EvaluationVarsBase) -> Vec { + Vec::new() + } + fn eval_unfiltered_recursively( &self, _builder: &mut CircuitBuilder, diff --git a/src/gates/public_input.rs b/src/gates/public_input.rs index a86b78d5..e1ce9271 100644 --- a/src/gates/public_input.rs +++ b/src/gates/public_input.rs @@ -5,7 +5,7 @@ use crate::field::extension_field::target::ExtensionTarget; use crate::field::extension_field::Extendable; use crate::gates::gate::{Gate, GateRef}; use crate::generator::WitnessGenerator; -use crate::vars::{EvaluationTargets, EvaluationVars}; +use crate::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase}; /// A gate whose first four wires will be equal to a hash of public inputs. pub struct PublicInputGate; @@ -32,6 +32,13 @@ impl, const D: usize> Gate for PublicInputGate { .collect() } + fn eval_unfiltered_base(&self, vars: EvaluationVarsBase) -> Vec { + Self::wires_public_inputs_hash() + .zip(vars.public_inputs_hash.elements) + .map(|(wire, hash_part)| vars.local_wires[wire] - hash_part) + .collect() + } + fn eval_unfiltered_recursively( &self, builder: &mut CircuitBuilder, diff --git a/src/prover.rs b/src/prover.rs index 59b3cd3d..7d209667 100644 --- a/src/prover.rs +++ b/src/prover.rs @@ -78,9 +78,10 @@ pub(crate) fn prove, const D: usize>( ); let mut challenger = Challenger::new(); + // Observe the instance. - // TODO: Need to include public inputs as well. challenger.observe_hash(&common_data.circuit_digest); + challenger.observe_hash(&public_inputs_hash); challenger.observe_hash(&wires_commitment.merkle_tree.root); let betas = challenger.get_n_challenges(num_challenges); diff --git a/src/recursive_verifier.rs b/src/recursive_verifier.rs index 0e4080ea..cc92ccdb 100644 --- a/src/recursive_verifier.rs +++ b/src/recursive_verifier.rs @@ -28,18 +28,20 @@ impl, const D: usize> CircuitBuilder { } = proof_with_pis; let one = self.one_extension(); - let public_inputs_hash = &self.hash_n_to_hash(public_inputs, true); - let num_challenges = inner_config.num_challenges; + let public_inputs_hash = &self.hash_n_to_hash(public_inputs, true); + let mut challenger = RecursiveChallenger::new(self); let (betas, gammas, alphas, zeta) = context!(self, "observe proof and generates challenges", { + // Observe the instance. let digest = HashTarget::from_vec( self.constants(&inner_common_data.circuit_digest.elements), ); challenger.observe_hash(&digest); + challenger.observe_hash(&public_inputs_hash); challenger.observe_hash(&proof.wires_root); let betas = challenger.get_n_challenges(self, num_challenges); diff --git a/src/vars.rs b/src/vars.rs index 8e98d41f..66ce2efb 100644 --- a/src/vars.rs +++ b/src/vars.rs @@ -3,7 +3,7 @@ use std::ops::Range; use crate::field::extension_field::algebra::ExtensionAlgebra; use crate::field::extension_field::target::{ExtensionAlgebraTarget, ExtensionTarget}; -use crate::field::extension_field::Extendable; +use crate::field::extension_field::{Extendable, FieldExtension}; use crate::field::field::Field; use crate::proof::{Hash, HashTarget}; @@ -37,6 +37,15 @@ impl<'a, F: Extendable, const D: usize> EvaluationVars<'a, F, D> { } impl<'a, F: Field> EvaluationVarsBase<'a, F> { + pub fn get_local_ext(&self, wire_range: Range) -> F::Extension + where + F: Extendable, + { + debug_assert_eq!(wire_range.len(), D); + let arr = self.local_wires[wire_range].try_into().unwrap(); + F::Extension::from_basefield_array(arr) + } + pub fn remove_prefix(&mut self, prefix: &[bool]) { self.local_constants = &self.local_constants[prefix.len()..]; } diff --git a/src/verifier.rs b/src/verifier.rs index 878b630a..d8af4cb4 100644 --- a/src/verifier.rs +++ b/src/verifier.rs @@ -25,9 +25,10 @@ pub(crate) fn verify, const D: usize>( let public_inputs_hash = &hash_n_to_hash(public_inputs, true); let mut challenger = Challenger::new(); + // Observe the instance. - // TODO: Need to include public inputs as well. challenger.observe_hash(&common_data.circuit_digest); + challenger.observe_hash(&public_inputs_hash); challenger.observe_hash(&proof.wires_root); let betas = challenger.get_n_challenges(num_challenges);