plonky2/src/field/prime_field_testing.rs
Daniel Lubarov a2eaaceb34
Rework the field test code a bit (#225)
- Split it into two files, one for general `Field` tests and one for `PrimeField` tests.
- Replace most uses of `BigUint` in tests with `u64`. These uses were only applicable for `PrimeField`s, which are 64-bit fields anyway. This lets us delete the `BigUint` conversion methods.
- Simplify `test_inputs`, which was originally written for large prime fields. Now that it's only used for 64-bit fields, I think interesting inputs are just the smallest and largest elements, and those close to 2^32 etc.
2021-09-07 14:17:15 -07:00

164 lines
5.2 KiB
Rust

use crate::field::field_types::PrimeField;
/// Generates a series of non-negative integers less than `modulus` which cover a range of
/// interesting test values.
///
/// The candidates are the smallest values, the values surrounding `2^31`, `2^32` and `2^63`,
/// and the largest values below `modulus`. Candidates that are not below `modulus` are
/// discarded. The result is sorted in ascending order and contains no duplicates, so
/// overlapping chunks (e.g. for a tiny modulus, or one adjacent to a power of two) are only
/// reported once.
///
/// Any `modulus` is accepted, including values smaller than the chunk size; `modulus == 0`
/// yields an empty vector.
pub fn test_inputs(modulus: u64) -> Vec<u64> {
    const CHUNK_SIZE: u64 = 10;

    let mut inputs: Vec<u64> = (0..CHUNK_SIZE)
        .chain((1 << 31) - CHUNK_SIZE..(1 << 31) + CHUNK_SIZE)
        .chain((1 << 32) - CHUNK_SIZE..(1 << 32) + CHUNK_SIZE)
        .chain((1 << 63) - CHUNK_SIZE..(1 << 63) + CHUNK_SIZE)
        // Saturate so that a modulus smaller than `CHUNK_SIZE` cannot underflow.
        .chain(modulus.saturating_sub(CHUNK_SIZE)..modulus)
        .filter(|&x| x < modulus)
        .collect();

    // Chunks may overlap or interleave once values have been filtered against a small
    // modulus; normalise so every value is tested exactly once.
    inputs.sort_unstable();
    inputs.dedup();
    inputs
}
/// Apply the unary functions `op` and `expected_op` to every input from
/// `test_inputs(F::ORDER)` and panic on the first input for which the two results differ.
///
/// For each input `x`, `op` is called on `F::from_canonical_u64(x)` and its result is
/// converted back with `to_canonical_u64`; `expected_op` is called on `x` directly and acts
/// as the reference implementation.
///
/// # Panics
///
/// Panics with a message naming the expected value, the actual value and the input when
/// the two results disagree.
pub fn run_unaryop_test_cases<F, UnaryOp, ExpectedOp>(op: UnaryOp, expected_op: ExpectedOp)
where
    F: PrimeField,
    UnaryOp: Fn(F) -> F,
    ExpectedOp: Fn(u64) -> u64,
{
    for x in test_inputs(F::ORDER) {
        let actual = op(F::from_canonical_u64(x)).to_canonical_u64();
        let expected = expected_op(x);
        assert_eq!(
            actual, expected,
            "Expected {}, got {} for input {}",
            expected, actual, x
        );
    }
}
/// Check the binary function `op` against the reference `expected_op` on every ordered pair
/// of values drawn from `test_inputs(F::ORDER)`.
///
/// Both operands are lifted into the field with `F::from_canonical_u64`, and the result of
/// `op` is lowered with `to_canonical_u64` before being compared with `expected_op` applied
/// to the raw integers. Panics on the first mismatching pair.
pub fn run_binaryop_test_cases<F, BinaryOp, ExpectedOp>(op: BinaryOp, expected_op: ExpectedOp)
where
    F: PrimeField,
    BinaryOp: Fn(F, F) -> F,
    ExpectedOp: Fn(u64, u64) -> u64,
{
    let samples = test_inputs(F::ORDER);

    // Every ordered (lhs, rhs) combination, with lhs as the outer (slow-moving) index.
    let pairs = samples
        .iter()
        .flat_map(|&a| samples.iter().map(move |&b| (a, b)));

    for (lhs, rhs) in pairs {
        let lifted_lhs = F::from_canonical_u64(lhs);
        let lifted_rhs = F::from_canonical_u64(rhs);
        let actual = op(lifted_lhs, lifted_rhs).to_canonical_u64();
        let expected = expected_op(lhs, rhs);
        assert_eq!(
            actual, expected,
            "Expected {}, got {} for inputs ({}, {})",
            expected, actual, lhs, rhs
        );
    }
}
/// Generates a `prime_field_arithmetic` test module for the prime field type `$field`.
///
/// The generated tests compare the field's `add`, `sub`, `neg`, `mul` and `square` against
/// reference implementations on `u64`/`u128` integers reduced modulo `<$field>::ORDER`, and
/// additionally check inversion and a couple of wraparound edge cases.
///
/// `$field` is referenced from inside the nested module, so it must be a type path that
/// resolves there (for example an absolute `crate::...` path).
///
/// Paths into this crate are written with `$crate` so that they keep pointing at the
/// defining crate regardless of where the exported macro is invoked.
#[macro_export]
macro_rules! test_prime_field_arithmetic {
    ($field:ty) => {
        mod prime_field_arithmetic {
            use std::ops::{Add, Mul, Neg, Sub};

            use $crate::field::field_types::{Field, PrimeField};

            /// Addition must agree with `(x + y) mod ORDER`, computed in `u128` to avoid
            /// overflow.
            #[test]
            fn arithmetic_addition() {
                let modulus = <$field>::ORDER;
                $crate::field::prime_field_testing::run_binaryop_test_cases(<$field>::add, |x, y| {
                    ((x as u128 + y as u128) % (modulus as u128)) as u64
                })
            }

            /// Subtraction must agree with `(x - y) mod ORDER`.
            #[test]
            fn arithmetic_subtraction() {
                let modulus = <$field>::ORDER;
                $crate::field::prime_field_testing::run_binaryop_test_cases(<$field>::sub, |x, y| {
                    if x >= y {
                        x - y
                    } else {
                        // Inputs are below `modulus`, so this stays within `u64`.
                        modulus - y + x
                    }
                })
            }

            /// Negation must agree with `(-x) mod ORDER`, with zero mapping to zero.
            #[test]
            fn arithmetic_negation() {
                let modulus = <$field>::ORDER;
                $crate::field::prime_field_testing::run_unaryop_test_cases(<$field>::neg, |x| {
                    if x == 0 {
                        0
                    } else {
                        modulus - x
                    }
                })
            }

            /// Multiplication must agree with `(x * y) mod ORDER`, computed in `u128`.
            #[test]
            fn arithmetic_multiplication() {
                let modulus = <$field>::ORDER;
                $crate::field::prime_field_testing::run_binaryop_test_cases(<$field>::mul, |x, y| {
                    ((x as u128) * (y as u128) % (modulus as u128)) as u64
                })
            }

            /// Squaring must agree with `(x * x) mod ORDER`, computed in `u128`.
            #[test]
            fn arithmetic_square() {
                let modulus = <$field>::ORDER;
                $crate::field::prime_field_testing::run_unaryop_test_cases(
                    |x: $field| x.square(),
                    |x| ((x as u128 * x as u128) % (modulus as u128)) as u64,
                )
            }

            /// Zero has no inverse; for a few values near both ends of the range,
            /// `x * x.inverse()` must equal one.
            #[test]
            fn inversion() {
                let zero = <$field>::ZERO;
                let one = <$field>::ONE;
                let order = <$field>::ORDER;

                assert_eq!(zero.try_inverse(), None);

                for x in [1, 2, 3, order - 3, order - 2, order - 1] {
                    let x = <$field>::from_canonical_u64(x);
                    let inv = x.inverse();
                    assert_eq!(x * inv, one);
                }
            }

            /// `((ORDER + 1) / 2) * 2` must equal one, and subtracting that product from
            /// zero must give `NEG_ONE`.
            #[test]
            fn subtraction_double_wraparound() {
                type F = $field;

                let (a, b) = (F::from_canonical_u64((F::ORDER + 1u64) / 2u64), F::TWO);
                let x = a * b;
                assert_eq!(x, F::ONE);
                assert_eq!(F::ZERO - x, F::NEG_ONE);
            }

            /// Summing the same four operands in two different groupings must give the
            /// same field element.
            #[test]
            fn addition_double_wraparound() {
                type F = $field;

                let a = F::from_canonical_u64(u64::MAX - F::ORDER);
                let b = F::NEG_ONE;

                let c = (a + a) + (b + b);
                let d = (a + b) + (a + b);

                assert_eq!(c, d);
            }
        }
    };
}