diff --git a/src/field/crandall_field.rs b/src/field/crandall_field.rs index 9868336a..b3d96bbe 100644 --- a/src/field/crandall_field.rs +++ b/src/field/crandall_field.rs @@ -242,10 +242,6 @@ impl Field for CrandallField { Self(n) } - fn from_canonical_biguint(n: BigUint) -> Self { - Self(n.iter_u64_digits().next().unwrap_or(0)) - } - #[inline] fn from_noncanonical_u128(n: u128) -> Self { reduce128(n) @@ -361,10 +357,6 @@ impl PrimeField for CrandallField { fn to_noncanonical_u64(&self) -> u64 { self.0 } - - fn to_canonical_biguint(&self) -> BigUint { - BigUint::from(self.to_canonical_u64()) - } } impl Neg for CrandallField { diff --git a/src/field/extension_field/quadratic.rs b/src/field/extension_field/quadratic.rs index 2a1fef70..c0febe33 100644 --- a/src/field/extension_field/quadratic.rs +++ b/src/field/extension_field/quadratic.rs @@ -93,16 +93,6 @@ impl Field for QuadraticCrandallField { >::BaseField::from_canonical_u64(n).into() } - fn from_canonical_biguint(n: BigUint) -> Self { - let smaller = n.clone() % Self::CHARACTERISTIC; - let larger = n.clone() / Self::CHARACTERISTIC; - - Self([ - >::BaseField::from_canonical_biguint(smaller), - >::BaseField::from_canonical_biguint(larger), - ]) - } - fn from_noncanonical_u128(n: u128) -> Self { >::BaseField::from_noncanonical_u128(n).into() } diff --git a/src/field/extension_field/quartic.rs b/src/field/extension_field/quartic.rs index f37cfb92..1e92a8fa 100644 --- a/src/field/extension_field/quartic.rs +++ b/src/field/extension_field/quartic.rs @@ -126,23 +126,6 @@ impl Field for QuarticCrandallField { >::BaseField::from_canonical_u64(n).into() } - fn from_canonical_biguint(n: BigUint) -> Self { - let first = &n % Self::CHARACTERISTIC; - let mut remaining = &n / Self::CHARACTERISTIC; - let second = &remaining % Self::CHARACTERISTIC; - remaining = remaining / Self::CHARACTERISTIC; - let third = &remaining % Self::CHARACTERISTIC; - remaining = remaining / Self::CHARACTERISTIC; - let fourth = &remaining % Self::CHARACTERISTIC; - - Self([ - >::BaseField::from_canonical_biguint(first), - >::BaseField::from_canonical_biguint(second), - >::BaseField::from_canonical_biguint(third), - >::BaseField::from_canonical_biguint(fourth), - ]) - } - fn from_noncanonical_u128(n: u128) -> Self { >::BaseField::from_noncanonical_u128(n).into() } diff --git a/src/field/field_testing.rs b/src/field/field_testing.rs index 658fc65a..7bf7702e 100644 --- a/src/field/field_testing.rs +++ b/src/field/field_testing.rs @@ -1,156 +1,10 @@ -use num::{bigint::BigUint, Zero}; - use crate::field::field_types::PrimeField; -use crate::util::ceil_div_usize; - -/// Generates a series of non-negative integers less than -/// `modulus` which cover a range of values and which will -/// generate lots of carries, especially at `word_bits` word -/// boundaries. -pub fn test_inputs(modulus: BigUint, word_bits: usize) -> Vec { - //assert!(word_bits == 32 || word_bits == 64); - let modwords = ceil_div_usize(modulus.bits() as usize, word_bits); - // Start with basic set close to zero: 0 .. 10 - const BIGGEST_SMALL: u32 = 10; - let smalls: Vec<_> = (0..BIGGEST_SMALL).map(BigUint::from).collect(); - // ... and close to MAX: MAX - x - let word_max = (BigUint::from(1u32) << word_bits) - 1u32; - let multiple_words_max = (BigUint::from(1u32) << modwords * word_bits) - 1u32; - let bigs = smalls.iter().map(|x| &word_max - x).collect(); - let one_words = [smalls, bigs].concat(); - // For each of the one word inputs above, create a new one at word i. - // TODO: Create all possible `modwords` combinations of those - let multiple_words = (1..modwords) - .flat_map(|i| { - one_words - .iter() - .map(|x| x << (word_bits * i)) - .collect::>() - }) - .collect(); - let basic_inputs: Vec = [one_words, multiple_words].concat(); - - // Biggest value that will fit in `modwords` words - // Inputs 'difference from' maximum value - let diff_max = basic_inputs - .iter() - .map(|x| &multiple_words_max - x) - .filter(|x| x < &modulus) - .collect(); - // Inputs 'difference from' modulus value - let diff_mod = basic_inputs - .iter() - .filter(|&x| x < &modulus && !x.is_zero()) - .map(|x| &modulus - x) - .collect(); - let basics = basic_inputs - .into_iter() - .filter(|x| x < &modulus) - .collect::>(); - [basics, diff_max, diff_mod].concat() - - // // There should be a nicer way to express the code above; something - // // like this (and removing collect() calls from diff_max and diff_mod): - // basic_inputs.into_iter() - // .chain(diff_max) - // .chain(diff_mod) - // .filter(|x| x < &modulus) - // .collect() -} - -/// Apply the unary functions `op` and `expected_op` -/// coordinate-wise to the inputs from `test_inputs(modulus, -/// word_bits)` and panic if the two resulting vectors differ. -pub fn run_unaryop_test_cases( - modulus: BigUint, - word_bits: usize, - op: UnaryOp, - expected_op: ExpectedOp, -) where - F: PrimeField, - UnaryOp: Fn(F) -> F, - ExpectedOp: Fn(BigUint) -> BigUint, -{ - let inputs = test_inputs(modulus, word_bits); - let expected: Vec<_> = inputs.iter().map(|x| expected_op(x.clone())).collect(); - let output: Vec<_> = inputs - .iter() - .cloned() - .map(|x| op(F::from_canonical_biguint(x)).to_canonical_biguint()) - .collect(); - // Compare expected outputs with actual outputs - for i in 0..inputs.len() { - assert_eq!( - output[i], expected[i], - "Expected {}, got {} for input {}", - expected[i], output[i], inputs[i] - ); - } -} - -/// Apply the binary functions `op` and `expected_op` to each pair -/// in `zip(inputs, rotate_right(inputs, i))` where `inputs` is -/// `test_inputs(modulus, word_bits)` and `i` ranges from 0 to -/// `inputs.len()`. Panic if the two functions ever give -/// different answers. -pub fn run_binaryop_test_cases( - modulus: BigUint, - word_bits: usize, - op: BinaryOp, - expected_op: ExpectedOp, -) where - F: PrimeField, - BinaryOp: Fn(F, F) -> F, - ExpectedOp: Fn(BigUint, BigUint) -> BigUint, -{ - let inputs = test_inputs(modulus, word_bits); - - for i in 0..inputs.len() { - // Iterator over inputs rotated right by i places. Since - // cycle().skip(i) rotates left by i, we need to rotate by - // n_input_elts - i. - let shifted_inputs: Vec<_> = inputs - .iter() - .cycle() - .skip(inputs.len() - i) - .take(inputs.len()) - .collect(); - - // Calculate pointwise operations - let expected: Vec<_> = inputs - .iter() - .zip(shifted_inputs.clone()) - .map(|(x, y)| expected_op(x.clone(), y.clone())) - .collect(); - - let output: Vec<_> = inputs - .iter() - .zip(shifted_inputs.clone()) - .map(|(x, y)| { - op( - F::from_canonical_biguint(x.clone()), - F::from_canonical_biguint(y.clone()), - ) - .to_canonical_biguint() - }) - .collect(); - - // Compare expected outputs with actual outputs - for i in 0..inputs.len() { - assert_eq!( - output[i], expected[i], - "On inputs {} . {}, expected {} but got {}", - inputs[i], shifted_inputs[i], expected[i], output[i] - ); - } - } -} #[macro_export] macro_rules! test_field_arithmetic { ($field:ty) => { mod field_arithmetic { - use num::{bigint::BigUint, One, Zero}; + use num::bigint::BigUint; use rand::Rng; use crate::field::field_types::Field; @@ -177,18 +31,10 @@ macro_rules! test_field_arithmetic { #[test] fn negation() { - let zero = <$field>::ZERO; - let order = <$field>::order(); + type F = $field; - for i in [ - BigUint::zero(), - BigUint::one(), - BigUint::from(2u32), - &order - 1u32, - &order - 2u32, - ] { - let i_f = <$field>::from_canonical_biguint(i); - assert_eq!(i_f + -i_f, zero); + for x in [F::ZERO, F::ONE, F::TWO, F::NEG_ONE] { + assert_eq!(x + -x, F::ZERO); } } @@ -249,135 +95,3 @@ macro_rules! test_field_arithmetic { } }; } - -#[macro_export] -macro_rules! test_prime_field_arithmetic { - ($field:ty) => { - mod prime_field_arithmetic { - use std::ops::{Add, Mul, Neg, Sub}; - - use num::{bigint::BigUint, One, Zero}; - - use crate::field::field_types::Field; - - // Can be 32 or 64; doesn't have to be computer's actual word - // bits. Choosing 32 gives more tests... - const WORD_BITS: usize = 32; - - #[test] - fn arithmetic_addition() { - let modulus = <$field>::order(); - crate::field::field_testing::run_binaryop_test_cases( - modulus.clone(), - WORD_BITS, - <$field>::add, - |x, y| (&x + &y) % &modulus, - ) - } - - #[test] - fn arithmetic_subtraction() { - let modulus = <$field>::order(); - crate::field::field_testing::run_binaryop_test_cases( - modulus.clone(), - WORD_BITS, - <$field>::sub, - |x, y| { - if x >= y { - &x - &y - } else { - &modulus - &y + &x - } - }, - ) - } - - #[test] - fn arithmetic_negation() { - let modulus = <$field>::order(); - crate::field::field_testing::run_unaryop_test_cases( - modulus.clone(), - WORD_BITS, - <$field>::neg, - |x| { - if x.is_zero() { - BigUint::zero() - } else { - &modulus - &x - } - }, - ) - } - - #[test] - fn arithmetic_multiplication() { - let modulus = <$field>::order(); - crate::field::field_testing::run_binaryop_test_cases( - modulus.clone(), - WORD_BITS, - <$field>::mul, - |x, y| &x * &y % &modulus, - ) - } - - #[test] - fn arithmetic_square() { - let modulus = <$field>::order(); - crate::field::field_testing::run_unaryop_test_cases( - modulus.clone(), - WORD_BITS, - |x: $field| x.square(), - |x| (&x * &x) % &modulus, - ) - } - - #[test] - fn inversion() { - let zero = <$field>::ZERO; - let one = <$field>::ONE; - let order = <$field>::order(); - - assert_eq!(zero.try_inverse(), None); - - for x in [ - BigUint::one(), - BigUint::from(2u32), - BigUint::from(3u32), - &order - 3u32, - &order - 2u32, - &order - 1u32, - ] { - let x = <$field>::from_canonical_biguint(x); - let inv = x.inverse(); - assert_eq!(x * inv, one); - } - } - - #[test] - fn subtraction_double_wraparound() { - type F = $field; - - let (a, b) = ( - F::from_canonical_biguint((F::order() + 1u32) / 2u32), - F::TWO, - ); - let x = a * b; - assert_eq!(x, F::ONE); - assert_eq!(F::ZERO - x, F::NEG_ONE); - } - - #[test] - fn addition_double_wraparound() { - type F = $field; - - let a = F::from_canonical_biguint(u64::MAX - F::order()); - let b = F::NEG_ONE; - - let c = (a + a) + (b + b); - let d = (a + b) + (a + b); - - assert_eq!(c, d); - } - } - }; -} diff --git a/src/field/field_types.rs b/src/field/field_types.rs index d8d54d15..82f86c65 100644 --- a/src/field/field_types.rs +++ b/src/field/field_types.rs @@ -194,8 +194,6 @@ pub trait Field: Self::from_canonical_u64(b as u64) } - fn from_canonical_biguint(n: BigUint) -> Self; - /// Returns `n % Self::CHARACTERISTIC`. fn from_noncanonical_u128(n: u128) -> Self; @@ -328,8 +326,6 @@ pub trait PrimeField: Field { fn to_canonical_u64(&self) -> u64; fn to_noncanonical_u64(&self) -> u64; - - fn to_canonical_biguint(&self) -> BigUint; } /// An iterator over the powers of a certain base element `b`: `b^0, b^1, b^2, ...`. diff --git a/src/field/mod.rs b/src/field/mod.rs index 4011767d..6566d430 100644 --- a/src/field/mod.rs +++ b/src/field/mod.rs @@ -9,3 +9,5 @@ pub(crate) mod packed_field; #[cfg(test)] mod field_testing; +#[cfg(test)] +mod prime_field_testing; diff --git a/src/field/prime_field_testing.rs b/src/field/prime_field_testing.rs new file mode 100644 index 00000000..f6b88d0c --- /dev/null +++ b/src/field/prime_field_testing.rs @@ -0,0 +1,163 @@ +use crate::field::field_types::PrimeField; + +/// Generates a series of non-negative integers less than `modulus` which cover a range of +/// interesting test values. +pub fn test_inputs(modulus: u64) -> Vec { + const CHUNK_SIZE: u64 = 10; + + (0..CHUNK_SIZE) + .chain((1 << 31) - CHUNK_SIZE..(1 << 31) + CHUNK_SIZE) + .chain((1 << 32) - CHUNK_SIZE..(1 << 32) + CHUNK_SIZE) + .chain((1 << 63) - CHUNK_SIZE..(1 << 63) + CHUNK_SIZE) + .chain(modulus - CHUNK_SIZE..modulus) + .filter(|&x| x < modulus) + .collect() +} + +/// Apply the unary functions `op` and `expected_op` +/// coordinate-wise to the inputs from `test_inputs(modulus, +/// word_bits)` and panic if the two resulting vectors differ. +pub fn run_unaryop_test_cases(op: UnaryOp, expected_op: ExpectedOp) +where + F: PrimeField, + UnaryOp: Fn(F) -> F, + ExpectedOp: Fn(u64) -> u64, +{ + let inputs = test_inputs(F::ORDER); + let expected: Vec<_> = inputs.iter().map(|x| expected_op(x.clone())).collect(); + let output: Vec<_> = inputs + .iter() + .cloned() + .map(|x| op(F::from_canonical_u64(x)).to_canonical_u64()) + .collect(); + // Compare expected outputs with actual outputs + for i in 0..inputs.len() { + assert_eq!( + output[i], expected[i], + "Expected {}, got {} for input {}", + expected[i], output[i], inputs[i] + ); + } +} + +/// Apply the binary functions `op` and `expected_op` to each pair of inputs. +pub fn run_binaryop_test_cases(op: BinaryOp, expected_op: ExpectedOp) +where + F: PrimeField, + BinaryOp: Fn(F, F) -> F, + ExpectedOp: Fn(u64, u64) -> u64, +{ + let inputs = test_inputs(F::ORDER); + + for &lhs in &inputs { + for &rhs in &inputs { + let lhs_f = F::from_canonical_u64(lhs); + let rhs_f = F::from_canonical_u64(rhs); + let actual = op(lhs_f, rhs_f).to_canonical_u64(); + let expected = expected_op(lhs, rhs); + assert_eq!( + actual, expected, + "Expected {}, got {} for inputs ({}, {})", + expected, actual, lhs, rhs + ); + } + } +} + +#[macro_export] +macro_rules! test_prime_field_arithmetic { + ($field:ty) => { + mod prime_field_arithmetic { + use std::ops::{Add, Mul, Neg, Sub}; + + use crate::field::field_types::{Field, PrimeField}; + + #[test] + fn arithmetic_addition() { + let modulus = <$field>::ORDER; + crate::field::prime_field_testing::run_binaryop_test_cases(<$field>::add, |x, y| { + ((x as u128 + y as u128) % (modulus as u128)) as u64 + }) + } + + #[test] + fn arithmetic_subtraction() { + let modulus = <$field>::ORDER; + crate::field::prime_field_testing::run_binaryop_test_cases(<$field>::sub, |x, y| { + if x >= y { + x - y + } else { + modulus - y + x + } + }) + } + + #[test] + fn arithmetic_negation() { + let modulus = <$field>::ORDER; + crate::field::prime_field_testing::run_unaryop_test_cases(<$field>::neg, |x| { + if x == 0 { + 0 + } else { + modulus - x + } + }) + } + + #[test] + fn arithmetic_multiplication() { + let modulus = <$field>::ORDER; + crate::field::prime_field_testing::run_binaryop_test_cases(<$field>::mul, |x, y| { + ((x as u128) * (y as u128) % (modulus as u128)) as u64 + }) + } + + #[test] + fn arithmetic_square() { + let modulus = <$field>::ORDER; + crate::field::prime_field_testing::run_unaryop_test_cases( + |x: $field| x.square(), + |x| ((x as u128 * x as u128) % (modulus as u128)) as u64, + ) + } + + #[test] + fn inversion() { + let zero = <$field>::ZERO; + let one = <$field>::ONE; + let order = <$field>::ORDER; + + assert_eq!(zero.try_inverse(), None); + + for x in [1, 2, 3, order - 3, order - 2, order - 1] { + let x = <$field>::from_canonical_u64(x); + let inv = x.inverse(); + assert_eq!(x * inv, one); + } + } + + #[test] + fn subtraction_double_wraparound() { + type F = $field; + + let (a, b) = (F::from_canonical_u64((F::ORDER + 1u64) / 2u64), F::TWO); + let x = a * b; + assert_eq!(x, F::ONE); + assert_eq!(F::ZERO - x, F::NEG_ONE); + } + + #[test] + fn addition_double_wraparound() { + type F = $field; + + let a = F::from_canonical_u64(u64::MAX - F::ORDER); + let b = F::NEG_ONE; + + let c = (a + a) + (b + b); + let d = (a + b) + (a + b); + + assert_eq!(c, d); + } + } + }; +}