diff --git a/plonky2/src/hash/arch/aarch64/poseidon_goldilocks_neon.rs b/plonky2/src/hash/arch/aarch64/poseidon_goldilocks_neon.rs index a7f61bf5..f903cd96 100644 --- a/plonky2/src/hash/arch/aarch64/poseidon_goldilocks_neon.rs +++ b/plonky2/src/hash/arch/aarch64/poseidon_goldilocks_neon.rs @@ -5,6 +5,7 @@ use std::arch::asm; use plonky2_field::field_types::PrimeField; use plonky2_field::goldilocks_field::GoldilocksField; +use plonky2_util::branch_hint; use static_assertions::const_assert; use unroll::unroll_for_loops; @@ -108,6 +109,8 @@ const_assert!(check_round_const_bounds_init()); // ====================================== SCALAR ARITHMETIC ======================================= +const EPSILON: u64 = 0xffffffff; + /// Addition modulo ORDER accounting for wraparound. Correct only when a + b < 2**64 + ORDER. #[inline(always)] unsafe fn add_with_wraparound(a: u64, b: u64) -> u64 { @@ -124,39 +127,36 @@ unsafe fn add_with_wraparound(a: u64, b: u64) -> u64 { adj = lateout(reg) adj, options(pure, nomem, nostack), ); - res.wrapping_add(adj) // adj is EPSILON if wraparound occured and 0 otherwise + res + adj // adj is EPSILON if wraparound occurred and 0 otherwise } -/// Addition of a and (b >> 32) modulo ORDER accounting for wraparound. +/// Subtraction of (b >> 32) from a modulo ORDER accounting for wraparound. #[inline(always)] unsafe fn sub_with_wraparound_lsr32(a: u64, b: u64) -> u64 { - let res: u64; - let adj: u64; - asm!( - "subs {res}, {a}, {b}, lsr #32", - // Set adj to 0xffffffff if subtraction underflowed and 0 otherwise. - // 'cc' for 'carry clear'. - // NB: The CF in ARM subtraction is the opposite of x86: CF set == underflow did not occur. - "csetm {adj:w}, cc", - a = in(reg) a, - b = in(reg) b, - res = lateout(reg) res, - adj = lateout(reg) adj, - options(pure, nomem, nostack), - ); - res.wrapping_sub(adj) // adj is EPSILON if underflow occured and 0 otherwise. 
+ let b_hi = b >> 32; + // This could be done with a.overflowing_sub(b_hi), but `checked_sub` signals to the compiler + // that overflow is unlikely (note: this is a standard library implementation detail, not part + // of the spec). + match a.checked_sub(b_hi) { + Some(res) => res, + None => { + // Super rare. Better off branching. + branch_hint(); + let res_wrapped = a.wrapping_sub(b_hi); + res_wrapped - EPSILON + } + } } /// Multiplication of the low word (i.e., x as u32) by EPSILON. #[inline(always)] unsafe fn mul_epsilon(x: u64) -> u64 { let res; - let epsilon: u64 = 0xffffffff; asm!( // Use UMULL to save one instruction. The compiler emits two: extract the low word and then multiply. "umull {res}, {x:w}, {epsilon:w}", x = in(reg) x, - epsilon = in(reg) epsilon, + epsilon = in(reg) EPSILON, res = lateout(reg) res, options(pure, nomem, nostack, preserves_flags), );