diff --git a/src/field/crandall_field.rs b/src/field/crandall_field.rs index 7e7e248b..a43de0d8 100644 --- a/src/field/crandall_field.rs +++ b/src/field/crandall_field.rs @@ -372,7 +372,7 @@ impl Add for CrandallField { #[inline] #[allow(clippy::suspicious_arithmetic_impl)] fn add(self, rhs: Self) -> Self { - let (sum, over) = self.0.overflowing_add(rhs.0); + let (sum, over) = self.0.overflowing_add(rhs.to_canonical_u64()); Self(sum.overflowing_sub((over as u64) * FIELD_ORDER).0) } } @@ -452,6 +452,17 @@ impl Extendable<4> for CrandallField { type Extension = QuarticCrandallField; } +/// Faster addition for when we know that lhs.0 + rhs.0 < 2^64 + FIELD_ORDER. If this is the case, +/// then the .to_canonical_u64() that addition usually performs is unnecessary. Omitting it saves +/// three instructions. +/// This function is marked unsafe because it may yield incorrect result if the condition is not +/// satisfied. +#[inline] +unsafe fn add_no_canonicalize(lhs: CrandallField, rhs: CrandallField) -> CrandallField { + let (sum, over) = lhs.0.overflowing_add(rhs.0); + CrandallField(sum.overflowing_sub((over as u64) * FIELD_ORDER).0) +} + /// Reduces to a 64-bit value. The result might not be in canonical form; it could be in between the /// field order and `2^64`. #[inline] @@ -465,7 +476,12 @@ fn reduce128(x: u128) -> CrandallField { let (lo_2, hi_2) = split((EPSILON as u128) * (hi_1 as u128) + (lo_1 as u128)); let lo_3 = hi_2 * EPSILON; - CrandallField(lo_2) + CrandallField(lo_3) + unsafe { + // This is safe to do because lo_2 + lo_3 < 2^64 + FIELD_ORDER. Notice that hi_2 <= + // 2^32 - 1. Then lo_3 = hi_2 * EPSILON <= (2^32 - 1) * EPSILON < FIELD_ORDER. + // Use of standard addition here would make multiplication 20% more expensive. + add_no_canonicalize(CrandallField(lo_2), CrandallField(lo_3)) + } } #[inline] diff --git a/src/field/field_testing.rs b/src/field/field_testing.rs index 88754d68..51275bf4 100644 --- a/src/field/field_testing.rs +++ b/src/field/field_testing.rs @@ -364,7 +364,7 @@ macro_rules! test_prime_field_arithmetic { } #[test] - fn subtraction() { + fn subtraction_double_wraparound() { type F = $field; let (a, b) = ( @@ -375,6 +375,19 @@ macro_rules! test_prime_field_arithmetic { assert_eq!(x, F::ONE); assert_eq!(F::ZERO - x, F::NEG_ONE); } + + #[test] + fn addition_double_wraparound() { + type F = $field; + + let a = F::from_canonical_biguint(u64::MAX - F::order()); + let b = F::NEG_ONE; + + let c = (a + a) + (b + b); + let d = (a + b) + (a + b); + + assert_eq!(c, d); + } } }; }