diff --git a/src/field/goldilocks_field.rs b/src/field/goldilocks_field.rs index 6b13dd80..16f25037 100644 --- a/src/field/goldilocks_field.rs +++ b/src/field/goldilocks_field.rs @@ -13,6 +13,7 @@ use crate::field::extension_field::quartic::QuarticExtension; use crate::field::extension_field::Extendable; use crate::field::field_types::{Field, PrimeField, RichField}; use crate::field::inversion::try_inverse_u64; +use crate::util::{assume, branch_hint}; const EPSILON: u64 = (1 << 32) - 1; @@ -136,13 +137,15 @@ impl PrimeField for GoldilocksField { #[inline] unsafe fn add_canonical_u64(&self, rhs: u64) -> Self { let (res_wrapped, carry) = self.0.overflowing_add(rhs); - Self(res_wrapped.wrapping_add(EPSILON * (carry as u64))) + // Add EPSILON * carry cannot overflow unless rhs is not in canonical form. + Self(res_wrapped + EPSILON * (carry as u64)) } #[inline] unsafe fn sub_canonical_u64(&self, rhs: u64) -> Self { let (res_wrapped, borrow) = self.0.overflowing_sub(rhs); - Self(res_wrapped.wrapping_sub(EPSILON * (borrow as u64))) + // Sub EPSILON * carry cannot underflow unless rhs is not in canonical form. + Self(res_wrapped - EPSILON * (borrow as u64)) } } @@ -165,8 +168,21 @@ impl Add for GoldilocksField { #[inline] #[allow(clippy::suspicious_arithmetic_impl)] fn add(self, rhs: Self) -> Self { - let (sum, over) = self.0.overflowing_add(rhs.to_canonical_u64()); - Self(sum.wrapping_sub((over as u64) * Self::ORDER)) + let (sum, over) = self.0.overflowing_add(rhs.0); + let (mut sum, over) = sum.overflowing_add((over as u64) * EPSILON); + if over { + // NB: self.0 > Self::ORDER && rhs.0 > Self::ORDER is necessary but not sufficient for + // double-overflow. + // This assume does two things: + // 1. If compiler knows that either self.0 or rhs.0 <= ORDER, then it can skip this + // check. + // 2. Hints to the compiler how rare this double-overflow is (thus handled better with + // a branch). + assume(self.0 > Self::ORDER && rhs.0 > Self::ORDER); + branch_hint(); + sum += EPSILON; // Cannot overflow. + } + Self(sum) } } @@ -189,8 +205,21 @@ impl Sub for GoldilocksField { #[inline] #[allow(clippy::suspicious_arithmetic_impl)] fn sub(self, rhs: Self) -> Self { - let (diff, under) = self.0.overflowing_sub(rhs.to_canonical_u64()); - Self(diff.wrapping_add((under as u64) * Self::ORDER)) + let (diff, under) = self.0.overflowing_sub(rhs.0); + let (mut diff, under) = diff.overflowing_sub((under as u64) * EPSILON); + if under { + // NB: self.0 < EPSILON - 1 && rhs.0 > Self::ORDER is necessary but not sufficient for + // double-underflow. + // This assume does two things: + // 1. If compiler knows that either self.0 >= EPSILON - 1 or rhs.0 <= ORDER, then it + // can skip this check. + // 2. Hints to the compiler how rare this double-underflow is (thus handled better + // with a branch). + assume(self.0 < EPSILON - 1 && rhs.0 > Self::ORDER); + branch_hint(); + diff -= EPSILON; // Cannot underflow. + } + Self(diff) } } @@ -283,8 +312,6 @@ impl RichField for GoldilocksField {} #[inline(always)] #[cfg(target_arch = "x86_64")] unsafe fn add_no_canonicalize_trashing_input(x: u64, y: u64) -> u64 { - use crate::util::assume; - let res_wrapped: u64; let adjustment: u64; asm!( @@ -304,14 +331,17 @@ unsafe fn add_no_canonicalize_trashing_input(x: u64, y: u64) -> u64 { ); assume(x != 0 || (res_wrapped == y && adjustment == 0)); assume(y != 0 || (res_wrapped == x && adjustment == 0)); - res_wrapped.wrapping_add(adjustment) // Add EPSILON == subtract ORDER. + // Add EPSILON == subtract ORDER. + // Cannot overflow unless the assumption if x + y < 2**64 + ORDER is incorrect. + res_wrapped + adjustment } #[inline(always)] #[cfg(not(target_arch = "x86_64"))] unsafe fn add_no_canonicalize_trashing_input(x: u64, y: u64) -> u64 { let (res_wrapped, carry) = x.overflowing_add(y); - res_wrapped.wrapping_add(EPSILON * (carry as u64)) + // Below cannot overflow unless the assumption if x + y < 2**64 + ORDER is incorrect. + res_wrapped + EPSILON * (carry as u64) } /// Fast subtraction modulo ORDER for x86-64. @@ -322,8 +352,6 @@ unsafe fn add_no_canonicalize_trashing_input(x: u64, y: u64) -> u64 { #[inline(always)] #[cfg(target_arch = "x86_64")] unsafe fn sub_no_canonicalize_trashing_input(x: u64, y: u64) -> u64 { - use crate::util::assume; - let res_wrapped: u64; let adjustment: u64; asm!( @@ -334,14 +362,17 @@ unsafe fn sub_no_canonicalize_trashing_input(x: u64, y: u64) -> u64 { options(pure, nomem, nostack), ); assume(y != 0 || (res_wrapped == x && adjustment == 0)); - res_wrapped.wrapping_sub(adjustment) // Subtract EPSILON == add ORDER. + // Subtract EPSILON == add ORDER. + // Cannot underflow unless the assumption x - y >= -ORDER is incorrect. + res_wrapped - adjustment } #[inline(always)] #[cfg(not(target_arch = "x86_64"))] unsafe fn sub_no_canonicalize_trashing_input(x: u64, y: u64) -> u64 { let (res_wrapped, borrow) = x.overflowing_sub(y); - res_wrapped.wrapping_sub(EPSILON * (borrow as u64)) + // Below cannot underflow unless the assumption x - y >= -ORDER is incorrect. + res_wrapped - EPSILON * (borrow as u64) } /// Reduces to a 64-bit value. The result might not be in canonical form; it could be in between the diff --git a/src/util/mod.rs b/src/util/mod.rs index 94df5141..586033be 100644 --- a/src/util/mod.rs +++ b/src/util/mod.rs @@ -198,10 +198,27 @@ pub(crate) fn reverse_bits(n: usize, num_bits: usize) -> usize { } #[inline(always)] -pub(crate) unsafe fn assume(p: bool) { +pub(crate) fn assume(p: bool) { debug_assert!(p); if !p { - unreachable_unchecked(); + unsafe { + unreachable_unchecked(); + } + } +} + +/// Try to force Rust to emit a branch. Example: +/// if x > 2 { +/// y = foo(); +/// branch_hint(); +/// } else { +/// y = bar(); +/// } +/// This function has no semantics. It is a hint only. +#[inline(always)] +pub(crate) fn branch_hint() { + unsafe { + asm!("", options(nomem, nostack, preserves_flags)); } }