Small optimizations (#319)

* Small optimizations

* Small optimizations

* feedback

* inline

* feedback

* fix unused import
This commit is contained in:
Daniel Lubarov 2021-10-22 19:11:05 -07:00 committed by GitHub
parent db23416b04
commit 806641d13f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 71 additions and 16 deletions

View File

@ -86,6 +86,7 @@ impl Field for CrandallField {
#[inline]
fn from_canonical_u64(n: u64) -> Self {
debug_assert!(n < Self::ORDER);
Self(n)
}
@ -219,6 +220,12 @@ impl PrimeField for CrandallField {
let (sum, over) = self.0.overflowing_add(rhs);
Self(sum.wrapping_sub((over as u64) * Self::ORDER))
}
#[inline]
unsafe fn sub_canonical_u64(&self, rhs: u64) -> Self {
let (sum, under) = self.0.overflowing_sub(rhs);
Self(sum.wrapping_add((under as u64) * Self::ORDER))
}
}
impl Neg for CrandallField {

View File

@ -62,26 +62,32 @@ pub trait Field:
fn order() -> BigUint;
#[inline]
fn is_zero(&self) -> bool {
*self == Self::ZERO
}
#[inline]
fn is_nonzero(&self) -> bool {
*self != Self::ZERO
}
#[inline]
fn is_one(&self) -> bool {
*self == Self::ONE
}
#[inline]
fn double(&self) -> Self {
*self + *self
}
#[inline]
fn square(&self) -> Self {
*self * *self
}
#[inline]
fn cube(&self) -> Self {
self.square() * *self
}
@ -340,6 +346,16 @@ pub trait PrimeField: Field {
fn from_noncanonical_u64(n: u64) -> Self;
#[inline]
fn add_one(&self) -> Self {
unsafe { self.add_canonical_u64(1) }
}
#[inline]
fn sub_one(&self) -> Self {
unsafe { self.sub_canonical_u64(1) }
}
/// Equivalent to *self + Self::from_canonical_u64(rhs), but may be cheaper. The caller must
/// ensure that 0 <= rhs < Self::ORDER. The function may return incorrect results if this
/// precondition is not met. It is marked unsafe for this reason.
@ -348,6 +364,15 @@ pub trait PrimeField: Field {
// Default implementation.
*self + Self::from_canonical_u64(rhs)
}
/// Equivalent to *self - Self::from_canonical_u64(rhs), but may be cheaper. The caller must
/// ensure that 0 <= rhs < Self::ORDER. The function may return incorrect results if this
/// precondition is not met. It is marked unsafe for this reason.
#[inline]
unsafe fn sub_canonical_u64(&self, rhs: u64) -> Self {
// Default implementation.
*self - Self::from_canonical_u64(rhs)
}
}
/// An iterator over the powers of a certain base element `b`: `b^0, b^1, b^2, ...`.

View File

@ -13,7 +13,6 @@ use crate::field::extension_field::quartic::QuarticExtension;
use crate::field::extension_field::Extendable;
use crate::field::field_types::{Field, PrimeField, RichField};
use crate::field::inversion::try_inverse_u64;
use crate::util::assume;
const EPSILON: u64 = (1 << 32) - 1;
@ -93,6 +92,7 @@ impl Field for GoldilocksField {
#[inline]
fn from_canonical_u64(n: u64) -> Self {
debug_assert!(n < Self::ORDER);
Self(n)
}
@ -103,6 +103,12 @@ impl Field for GoldilocksField {
fn rand_from_rng<R: Rng>(rng: &mut R) -> Self {
Self::from_canonical_u64(rng.gen_range(0..Self::ORDER))
}
#[inline]
fn multiply_accumulate(&self, x: Self, y: Self) -> Self {
// u64 + u64 * u64 cannot overflow.
reduce128((self.0 as u128) + (x.0 as u128) * (y.0 as u128))
}
}
impl PrimeField for GoldilocksField {
@ -126,6 +132,18 @@ impl PrimeField for GoldilocksField {
fn from_noncanonical_u64(n: u64) -> Self {
Self(n)
}
#[inline]
unsafe fn add_canonical_u64(&self, rhs: u64) -> Self {
let (res_wrapped, carry) = self.0.overflowing_add(rhs);
Self(res_wrapped.wrapping_add(EPSILON * (carry as u64)))
}
#[inline]
unsafe fn sub_canonical_u64(&self, rhs: u64) -> Self {
let (res_wrapped, borrow) = self.0.overflowing_sub(rhs);
Self(res_wrapped.wrapping_sub(EPSILON * (borrow as u64)))
}
}
impl Neg for GoldilocksField {
@ -153,6 +171,7 @@ impl Add for GoldilocksField {
}
impl AddAssign for GoldilocksField {
#[inline]
fn add_assign(&mut self, rhs: Self) {
*self = *self + rhs;
}
@ -263,7 +282,9 @@ impl RichField for GoldilocksField {}
/// the registers, so its use is not recommended when either input will be used again.
#[inline(always)]
#[cfg(target_arch = "x86_64")]
unsafe fn add_with_wraparound(x: u64, y: u64) -> u64 {
unsafe fn add_no_canonicalize_trashing_input(x: u64, y: u64) -> u64 {
use crate::util::assume;
let res_wrapped: u64;
let adjustment: u64;
asm!(
@ -288,7 +309,7 @@ unsafe fn add_with_wraparound(x: u64, y: u64) -> u64 {
#[inline(always)]
#[cfg(not(target_arch = "x86_64"))]
unsafe fn add_with_wraparound(x: u64, y: u64) -> u64 {
unsafe fn add_no_canonicalize_trashing_input(x: u64, y: u64) -> u64 {
let (res_wrapped, carry) = x.overflowing_add(y);
res_wrapped.wrapping_add(EPSILON * (carry as u64))
}
@ -300,12 +321,14 @@ unsafe fn add_with_wraparound(x: u64, y: u64) -> u64 {
/// the registers, so its use is not recommended when either input will be used again.
#[inline(always)]
#[cfg(target_arch = "x86_64")]
unsafe fn sub_with_wraparound(x: u64, y: u64) -> u64 {
unsafe fn sub_no_canonicalize_trashing_input(x: u64, y: u64) -> u64 {
use crate::util::assume;
let res_wrapped: u64;
let adjustment: u64;
asm!(
"sub {0}, {1}",
"sbb {1:e}, {1:e}", // See add_with_wraparound.
"sbb {1:e}, {1:e}", // See add_no_canonicalize_trashing_input.
inlateout(reg) x => res_wrapped,
inlateout(reg) y => adjustment,
options(pure, nomem, nostack),
@ -316,7 +339,7 @@ unsafe fn sub_with_wraparound(x: u64, y: u64) -> u64 {
#[inline(always)]
#[cfg(not(target_arch = "x86_64"))]
unsafe fn sub_with_wraparound(x: u64, y: u64) -> u64 {
unsafe fn sub_no_canonicalize_trashing_input(x: u64, y: u64) -> u64 {
let (res_wrapped, borrow) = x.overflowing_sub(y);
res_wrapped.wrapping_sub(EPSILON * (borrow as u64))
}
@ -329,9 +352,9 @@ fn reduce128(x: u128) -> GoldilocksField {
let x_hi_hi = x_hi >> 32;
let x_hi_lo = x_hi & EPSILON;
let t0 = unsafe { sub_with_wraparound(x_lo, x_hi_hi) };
let t0 = unsafe { sub_no_canonicalize_trashing_input(x_lo, x_hi_hi) };
let t1 = x_hi_lo * EPSILON;
let t2 = unsafe { add_with_wraparound(t0, t1) };
let t2 = unsafe { add_no_canonicalize_trashing_input(t0, t1) };
GoldilocksField(t2)
}

View File

@ -68,7 +68,7 @@ impl<F: RichField + Extendable<D>, const D: usize, const B: usize> Gate<F, D> fo
let constraints_iter = limbs.iter().map(|&limb| {
(0..B)
.map(|i| limb - F::from_canonical_usize(i))
.map(|i| unsafe { limb.sub_canonical_u64(i as u64) })
.product::<F>()
});
constraints.extend(constraints_iter);

View File

@ -118,7 +118,7 @@ impl<F: RichField + Extendable<D> + GMiMC<WIDTH>, const D: usize, const WIDTH: u
// Assert that `swap` is binary.
let swap = vars.local_wires[Self::WIRE_SWAP];
constraints.push(swap * (swap - F::ONE));
constraints.push(swap * swap.sub_one());
let mut state = Vec::with_capacity(12);
for i in 0..4 {

View File

@ -3,7 +3,7 @@ use std::marker::PhantomData;
use crate::field::extension_field::target::ExtensionTarget;
use crate::field::extension_field::Extendable;
use crate::field::field_types::{Field, RichField};
use crate::field::field_types::{Field, PrimeField, RichField};
use crate::gates::gate::Gate;
use crate::hash::poseidon;
use crate::hash::poseidon::Poseidon;
@ -180,7 +180,7 @@ where
// Assert that `swap` is binary.
let swap = vars.local_wires[Self::WIRE_SWAP];
constraints.push(swap * (swap - F::ONE));
constraints.push(swap * swap.sub_one());
let mut state = Vec::with_capacity(WIDTH);
for i in 0..4 {

View File

@ -477,8 +477,8 @@ where
#[inline(always)]
fn sbox_monomial<F: FieldExtension<D, BaseField = Self>, const D: usize>(x: F) -> F {
// x |--> x^7
let x2 = x * x;
let x4 = x2 * x2;
let x2 = x.square();
let x4 = x2.square();
let x3 = x * x2;
x3 * x4
}

View File

@ -1,6 +1,6 @@
use crate::field::extension_field::target::ExtensionTarget;
use crate::field::extension_field::{Extendable, FieldExtension};
use crate::field::field_types::{Field, RichField};
use crate::field::field_types::{Field, PrimeField, RichField};
use crate::gates::gate::PrefixedGate;
use crate::iop::target::Target;
use crate::plonk::circuit_builder::CircuitBuilder;
@ -140,7 +140,7 @@ pub(crate) fn eval_vanishing_poly_base<F: RichField + Extendable<D>, const D: us
for i in 0..num_challenges {
let z_x = local_zs[i];
let z_gz = next_zs[i];
vanishing_z_1_terms.push(l1_x * (z_x - F::ONE));
vanishing_z_1_terms.push(l1_x * z_x.sub_one());
numerator_values.extend((0..num_routed_wires).map(|j| {
let wire_value = vars.local_wires[j];