diff --git a/src/field/crandall_field.rs b/src/field/crandall_field.rs index 1425d68e..548c49b8 100644 --- a/src/field/crandall_field.rs +++ b/src/field/crandall_field.rs @@ -55,7 +55,7 @@ impl Field for CrandallField { const MULTIPLICATIVE_SUBGROUP_GENERATOR: Self = Self(5); // TODO: Double check. #[inline] - fn sq(&self) -> Self { + fn square(&self) -> Self { *self * *self } @@ -242,4 +242,7 @@ fn split(x: u128) -> (u64, u64) { #[cfg(test)] mod tests { + use crate::test_arithmetic; + + test_arithmetic!(crate::CrandallField); } diff --git a/src/field/fft.rs b/src/field/fft.rs index df21856f..d3c98766 100644 --- a/src/field/fft.rs +++ b/src/field/fft.rs @@ -153,79 +153,69 @@ pub(crate) fn coset_ifft(poly: PolynomialValues, shift: F) -> Polyn coeffs } -// #[cfg(test)] -// mod tests { -// use crate::{Bls12377Scalar, fft_precompute, fft_with_precomputation, CrandallField, ifft_with_precomputation_power_of_2}; -// use crate::fft::{log2_strict, reverse_bits, reverse_index_bits}; -// use crate::util::log2_ceil; -// -// #[test] -// fn fft_and_ifft() { -// let degree = 200; -// let degree_padded = log2_ceil(degree); -// let mut coefficients = Vec::new(); -// for i in 0..degree { -// coefficients.push(Bls12377Scalar::from_canonical_usize(i * 1337 % 100)); -// } -// -// let precomputation = fft_precompute(degree); -// let points = fft_with_precomputation(&coefficients, &precomputation); -// assert_eq!(points, evaluate_naive(&coefficients)); -// -// let interpolated_coefficients = -// ifft_with_precomputation_power_of_2(&points, &precomputation); -// for i in 0..degree { -// assert_eq!(interpolated_coefficients[i], coefficients[i]); -// } -// for i in degree..degree_padded { -// assert_eq!(interpolated_coefficients[i], Bls12377Scalar::ZERO); -// } -// } -// -// #[test] -// fn test_reverse_bits() { -// assert_eq!(reverse_bits(0b00110101, 8), 0b10101100); -// assert_eq!(reverse_index_bits(vec!["a", "b"]), vec!["a", "b"]); -// assert_eq!( -// reverse_index_bits(vec!["a", "b", "c", "d"]), -// vec!["a", "c", "b", "d"] -// ); -// } -// -// fn evaluate_naive(coefficients: &[CrandallField]) -> Vec { -// let degree = coefficients.len(); -// let degree_padded = 1 << log2_ceil(degree); -// -// let mut coefficients_padded = Vec::with_capacity(degree_padded); -// for c in coefficients { -// coefficients_padded.push(*c); -// } -// for _i in degree..degree_padded { -// coefficients_padded.push(F::ZERO); -// } -// evaluate_naive_power_of_2(&coefficients_padded) -// } -// -// fn evaluate_naive_power_of_2(coefficients: &[CrandallField]) -> Vec { -// let degree = coefficients.len(); -// let degree_pow = log2_strict(degree); -// -// let g = F::primitive_root_of_unity(degree_pow); -// let powers_of_g = F::cyclic_subgroup_known_order(g, degree); -// -// powers_of_g -// .into_iter() -// .map(|x| evaluate_at_naive(&coefficients, x)) -// .collect() -// } -// -// fn evaluate_at_naive(coefficients: &[CrandallField], point: F) -> F { -// let mut sum = F::ZERO; -// let mut point_power = F::ONE; -// for &c in coefficients { -// sum = sum + c * point_power; -// point_power = point_power * point; -// } -// sum -// } -// } +#[cfg(test)] +mod tests { + use crate::util::{log2_ceil, log2_strict}; + use crate::field::fft::{ifft, fft}; + use crate::polynomial::polynomial::{PolynomialCoeffs, PolynomialValues}; + use crate::field::field::Field; + use crate::field::crandall_field::CrandallField; + + #[test] + fn fft_and_ifft() { + type F = CrandallField; + let degree = 200; + let degree_padded = log2_ceil(degree); + let mut coefficients = Vec::new(); + for i in 0..degree { + coefficients.push(F::from_canonical_usize(i * 1337 % 100)); + } + let coefficients = PolynomialCoeffs::pad(coefficients); + + let points = fft(coefficients.clone()); + assert_eq!(points, evaluate_naive(&coefficients)); + + let interpolated_coefficients = ifft(points); + for i in 0..degree { + assert_eq!(interpolated_coefficients.coeffs[i], coefficients.coeffs[i]); + } + for i in degree..degree_padded { + assert_eq!(interpolated_coefficients.coeffs[i], F::ZERO); + } + } + + fn evaluate_naive(coefficients: &PolynomialCoeffs) -> PolynomialValues { + let degree = coefficients.len(); + let degree_padded = 1 << log2_ceil(degree); + + let mut coefficients_padded = coefficients.clone(); + for _i in degree..degree_padded { + coefficients_padded.coeffs.push(F::ZERO); + } + evaluate_naive_power_of_2(&coefficients_padded) + } + + fn evaluate_naive_power_of_2(coefficients: &PolynomialCoeffs) -> PolynomialValues { + let degree = coefficients.len(); + let degree_pow = log2_strict(degree); + + let g = F::primitive_root_of_unity(degree_pow); + let powers_of_g = F::cyclic_subgroup_known_order(g, degree); + + let values = powers_of_g + .into_iter() + .map(|x| evaluate_at_naive(&coefficients, x)) + .collect(); + PolynomialValues::new(values) + } + + fn evaluate_at_naive(coefficients: &PolynomialCoeffs, point: F) -> F { + let mut sum = F::ZERO; + let mut point_power = F::ONE; + for &c in &coefficients.coeffs { + sum = sum + c * point_power; + point_power = point_power * point; + } + sum + } +} diff --git a/src/field/field.rs b/src/field/field.rs index 59c94c64..239b0160 100644 --- a/src/field/field.rs +++ b/src/field/field.rs @@ -2,6 +2,7 @@ use std::fmt::{Debug, Display}; use std::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssign}; use rand::Rng; use rand::rngs::OsRng; +use crate::util::bits_u64; /// A finite field with prime order less than 2^64. pub trait Field: 'static @@ -36,7 +37,7 @@ pub trait Field: 'static *self == Self::ONE } - fn sq(&self) -> Self { + fn square(&self) -> Self { *self * *self } @@ -93,7 +94,7 @@ pub trait Field: 'static } fn bits(&self) -> usize { - 64 - self.to_canonical_u64().leading_zeros() as usize + bits_u64(self.to_canonical_u64()) } fn exp(&self, power: Self) -> Self { @@ -104,7 +105,7 @@ pub trait Field: 'static if (power.to_canonical_u64() >> j & 1) != 0 { product = product * current; } - current = current.sq(); + current = current.square(); } product } diff --git a/src/field/field_testing.rs b/src/field/field_testing.rs new file mode 100644 index 00000000..77bc3448 --- /dev/null +++ b/src/field/field_testing.rs @@ -0,0 +1,221 @@ +use crate::field::field::Field; +use crate::util::{bits_u64, ceil_div_usize}; + +/// Generates a series of non-negative integers less than +/// `modulus` which cover a range of values and which will +/// generate lots of carries, especially at `word_bits` word +/// boundaries. +pub fn test_inputs(modulus: u64, word_bits: usize) -> Vec { + assert!(word_bits == 32 || word_bits == 64); + let modwords = ceil_div_usize(bits_u64(modulus), word_bits); + // Start with basic set close to zero: 0 .. 10 + const BIGGEST_SMALL: u32 = 10; + let smalls: Vec<_> = (0..BIGGEST_SMALL).map(u64::from).collect(); + // ... and close to MAX: MAX - x + let word_max = (1u64 << word_bits) - 1; + let bigs = smalls.iter().map(|x| &word_max - x).collect(); + let one_words = [smalls, bigs].concat(); + // For each of the one word inputs above, create a new one at word i. + // TODO: Create all possible `modwords` combinations of those + let multiple_words = (1..modwords) + .flat_map(|i| { + one_words + .iter() + .map(|x| x << (word_bits * i)) + .collect::>() + }) + .collect(); + let basic_inputs: Vec = [one_words, multiple_words].concat(); + + // Biggest value that will fit in `modwords` words + // Inputs 'difference from' maximum value + let diff_max = basic_inputs + .iter() + .map(|&x| u64::MAX - x) + .filter(|&x| x < modulus) + .collect(); + // Inputs 'difference from' modulus value + let diff_mod = basic_inputs + .iter() + .filter(|x| **x < modulus && **x != 0) + .map(|&x| modulus - x) + .collect(); + let basics = basic_inputs + .into_iter() + .filter(|&x| x < modulus) + .collect::>(); + [basics, diff_max, diff_mod].concat() + + // // There should be a nicer way to express the code above; something + // // like this (and removing collect() calls from diff_max and diff_mod): + // basic_inputs.into_iter() + // .chain(diff_max) + // .chain(diff_mod) + // .filter(|x| x < &modulus) + // .collect() +} + + +/// Apply the unary functions `op` and `expected_op` +/// coordinate-wise to the inputs from `test_inputs(modulus, +/// word_bits)` and panic if the two resulting vectors differ. +pub fn run_unaryop_test_cases( + modulus: u64, + word_bits: usize, + op: UnaryOp, + expected_op: ExpectedOp, +) where + F: Field, + UnaryOp: Fn(F) -> F, + ExpectedOp: Fn(u64) -> u64, +{ + let inputs = test_inputs(modulus, word_bits); + let expected = inputs.iter().map(|&x| expected_op(x)); + let output = inputs + .iter() + .map(|&x| op(F::from_canonical_u64(x)).to_canonical_u64()); + // Compare expected outputs with actual outputs + assert!( + output.zip(expected).all(|(x, y)| x == y), + "output differs from expected" + ); +} + +/// Apply the binary functions `op` and `expected_op` to each pair +/// in `zip(inputs, rotate_right(inputs, i))` where `inputs` is +/// `test_inputs(modulus, word_bits)` and `i` ranges from 0 to +/// `inputs.len()`. Panic if the two functions ever give +/// different answers. +pub fn run_binaryop_test_cases( + modulus: u64, + word_bits: usize, + op: BinaryOp, + expected_op: ExpectedOp, +) where + F: Field, + BinaryOp: Fn(F, F) -> F, + ExpectedOp: Fn(u64, u64) -> u64, +{ + let inputs = test_inputs(modulus, word_bits); + + for i in 0..inputs.len() { + // Iterator over inputs rotated right by i places. Since + // cycle().skip(i) rotates left by i, we need to rotate by + // n_input_elts - i. + let shifted_inputs = inputs.iter().cycle().skip(inputs.len() - i); + // Calculate pointwise operations + let expected = inputs + .iter() + .zip(shifted_inputs.clone()) + .map(|(x, y)| expected_op(x.clone(), y.clone())); + let output = inputs.iter().zip(shifted_inputs).map(|(&x, &y)| { + op(F::from_canonical_u64(x), F::from_canonical_u64(y)).to_canonical_u64() + }); + // Compare expected outputs with actual outputs + assert!( + output.zip(expected).all(|(x, y)| x == y), + "output differs from expected at rotation {}", + i + ); + } +} + +#[macro_export] +macro_rules! test_arithmetic { + ($field:ty) => { + mod arithmetic { + use crate::{Field}; + use std::ops::{Add, Div, Mul, Neg, Sub}; + + // Can be 32 or 64; doesn't have to be computer's actual word + // bits. Choosing 32 gives more tests... + const WORD_BITS: usize = 32; + + #[test] + fn arithmetic_addition() { + let modulus = <$field>::ORDER; + crate::field::field_testing::run_binaryop_test_cases(modulus, WORD_BITS, <$field>::add, |x, y| { + let (z, over) = x.overflowing_add(y); + if over { + z.overflowing_sub(modulus).0 + } else if z >= modulus { + z - modulus + } else { + z + } + }) + } + + #[test] + fn arithmetic_subtraction() { + let modulus = <$field>::ORDER; + crate::field::field_testing::run_binaryop_test_cases(modulus, WORD_BITS, <$field>::sub, |x, y| { + if x >= y { + x - y + } else { + &modulus - y + x + } + }) + } + + #[test] + fn arithmetic_negation() { + let modulus = <$field>::ORDER; + crate::field::field_testing::run_unaryop_test_cases(modulus, WORD_BITS, <$field>::neg, |x| { + if x == 0 { + 0 + } else { + modulus - x + } + }) + } + + #[test] + fn arithmetic_multiplication() { + let modulus = <$field>::ORDER; + crate::field::field_testing::run_binaryop_test_cases(modulus, WORD_BITS, <$field>::mul, |x, y| { + x * y % modulus + }) + } + + #[test] + fn arithmetic_square() { + let modulus = <$field>::ORDER; + crate::field::field_testing::run_unaryop_test_cases( + modulus, WORD_BITS, + |x: $field| x.square(), + |x| x * x % modulus) + } + + // #[test] + // #[ignore] + // fn arithmetic_division() { + // // This test takes ages to finish so is #[ignore]d by default. + // // TODO: Re-enable and reimplement when + // // https://github.com/rust-num/num-bigint/issues/60 is finally resolved. + // let modulus = <$field>::ORDER; + // crate::field::field_testing::run_binaryop_test_cases( + // modulus, + // WORD_BITS, + // // Need to help the compiler infer the type of y here + // |x: $field, y: $field| { + // // TODO: Work out how to check that div() panics + // // appropriately when given a zero divisor. + // if !y.is_zero() { + // <$field>::div(x, y) + // } else { + // <$field>::ZERO + // } + // }, + // |x, y| { + // // yinv = y^-1 (mod modulus) + // let exp = modulus - 2u64; + // let yinv = y.modpow(exp, modulus); + // // returns 0 if y was 0 + // x * yinv % modulus + // }, + // ) + // } + } + }; +} diff --git a/src/field/mod.rs b/src/field/mod.rs index 027dd742..38fab717 100644 --- a/src/field/mod.rs +++ b/src/field/mod.rs @@ -2,3 +2,6 @@ pub(crate) mod crandall_field; pub(crate) mod field; pub(crate) mod field_search; pub(crate) mod fft; + +#[cfg(test)] +mod field_testing; diff --git a/src/polynomial/polynomial.rs b/src/polynomial/polynomial.rs index 36a7a944..6a96b072 100644 --- a/src/polynomial/polynomial.rs +++ b/src/polynomial/polynomial.rs @@ -6,7 +6,7 @@ use crate::util::log2_strict; /// /// The points are implicitly `g^i`, where `g` generates the subgroup whose size equals the number /// of points. -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Eq, PartialEq)] pub(crate) struct PolynomialValues { pub(crate) values: Vec, } @@ -42,7 +42,7 @@ impl PolynomialValues { } /// A polynomial in coefficient form. The number of coefficients must be a power of two. -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Eq, PartialEq)] pub(crate) struct PolynomialCoeffs { pub(crate) coeffs: Vec, } @@ -53,6 +53,13 @@ impl PolynomialCoeffs { PolynomialCoeffs { coeffs } } + pub(crate) fn pad(mut coeffs: Vec) -> Self { + while !coeffs.len().is_power_of_two() { + coeffs.push(F::ZERO); + } + PolynomialCoeffs { coeffs } + } + pub(crate) fn zero(len: usize) -> Self { Self::new(vec![F::ZERO; len]) } diff --git a/src/rescue.rs b/src/rescue.rs index a2346bfe..f97e0782 100644 --- a/src/rescue.rs +++ b/src/rescue.rs @@ -95,7 +95,7 @@ fn sbox_a(x: F) -> F { if ((EXP >> i) & 1) != 0 { product = product * current; } - current = current.sq(); + current = current.square(); } product } @@ -103,7 +103,7 @@ fn sbox_a(x: F) -> F { #[inline(always)] fn sbox_b(x: F) -> F { // x^5 - let x2 = x.sq(); + let x2 = x.square(); let x3 = x2 * x; x2 * x3 } diff --git a/src/util.rs b/src/util.rs index e12cf8a6..9e8e5849 100644 --- a/src/util.rs +++ b/src/util.rs @@ -1,6 +1,10 @@ use crate::field::field::Field; use crate::polynomial::polynomial::PolynomialValues; +pub(crate) fn bits_u64(n: u64) -> usize { + (64 - n.leading_zeros()) as usize +} + pub(crate) fn ceil_div_usize(a: usize, b: usize) -> usize { (a + b - 1) / b }