From db464f739e3cac1ff3eac054dd90dbfe9d43b3fc Mon Sep 17 00:00:00 2001 From: Nicholas Ward Date: Wed, 10 Nov 2021 11:49:30 -0800 Subject: [PATCH] merge --- src/curve/curve_adds.rs | 41 ++++- src/curve/curve_multiplication.rs | 89 ++++++++++ src/curve/curve_summation.rs | 236 +++++++++++++++++++++++++ src/curve/curve_types.rs | 14 +- src/curve/mod.rs | 2 + src/field/extension_field/quadratic.rs | 2 + src/field/extension_field/quartic.rs | 2 + src/field/field_types.rs | 2 + src/field/goldilocks_field.rs | 2 + src/field/secp256k1.rs | 2 + 10 files changed, 376 insertions(+), 16 deletions(-) create mode 100644 src/curve/curve_multiplication.rs create mode 100644 src/curve/curve_summation.rs diff --git a/src/curve/curve_adds.rs b/src/curve/curve_adds.rs index c5fc9ba4..32a5adcc 100644 --- a/src/curve/curve_adds.rs +++ b/src/curve/curve_adds.rs @@ -1,14 +1,24 @@ use std::ops::Add; -use crate::field::field_types::Field; use crate::curve::curve_types::{AffinePoint, Curve, ProjectivePoint}; +use crate::field::field_types::Field; impl Add> for ProjectivePoint { type Output = ProjectivePoint; fn add(self, rhs: ProjectivePoint) -> Self::Output { - let ProjectivePoint { x: x1, y: y1, z: z1, zero: zero1 } = self; - let ProjectivePoint { x: x2, y: y2, z: z2, zero: zero2 } = rhs; + let ProjectivePoint { + x: x1, + y: y1, + z: z1, + zero: zero1, + } = self; + let ProjectivePoint { + x: x2, + y: y2, + z: z2, + zero: zero2, + } = rhs; if zero1 { return rhs; @@ -52,8 +62,17 @@ impl Add> for ProjectivePoint { type Output = ProjectivePoint; fn add(self, rhs: AffinePoint) -> Self::Output { - let ProjectivePoint { x: x1, y: y1, z: z1, zero: zero1 } = self; - let AffinePoint { x: x2, y: y2, zero: zero2 } = rhs; + let ProjectivePoint { + x: x1, + y: y1, + z: z1, + zero: zero1, + } = self; + let AffinePoint { + x: x2, + y: y2, + zero: zero2, + } = rhs; if zero1 { return rhs.to_projective(); @@ -94,8 +113,16 @@ impl Add> for AffinePoint { type Output = ProjectivePoint; fn add(self, rhs: AffinePoint) -> Self::Output { - let AffinePoint { x: x1, y: y1, zero: zero1 } = self; - let AffinePoint { x: x2, y: y2, zero: zero2 } = rhs; + let AffinePoint { + x: x1, + y: y1, + zero: zero1, + } = self; + let AffinePoint { + x: x2, + y: y2, + zero: zero2, + } = rhs; if zero1 { return rhs.to_projective(); diff --git a/src/curve/curve_multiplication.rs b/src/curve/curve_multiplication.rs new file mode 100644 index 00000000..e5ac0eb3 --- /dev/null +++ b/src/curve/curve_multiplication.rs @@ -0,0 +1,89 @@ +use std::ops::Mul; + +use crate::curve::curve_summation::affine_summation_batch_inversion; +use crate::curve::curve_types::{AffinePoint, Curve, CurveScalar, ProjectivePoint}; +use crate::field::field_types::Field; + +const WINDOW_BITS: usize = 4; +const BASE: usize = 1 << WINDOW_BITS; + +fn digits_per_scalar() -> usize { + (C::ScalarField::BITS + WINDOW_BITS - 1) / WINDOW_BITS +} + +/// Precomputed state used for scalar x ProjectivePoint multiplications, +/// specific to a particular generator. +#[derive(Clone)] +pub struct MultiplicationPrecomputation { + /// [(2^w)^i] g for each i < digits_per_scalar. + powers: Vec>, +} + +impl ProjectivePoint { + pub fn mul_precompute(&self) -> MultiplicationPrecomputation { + let num_digits = digits_per_scalar::(); + let mut powers_proj = Vec::with_capacity(num_digits); + powers_proj.push(*self); + for i in 1..num_digits { + let mut power_i_proj = powers_proj[i - 1]; + for _j in 0..WINDOW_BITS { + power_i_proj = power_i_proj.double(); + } + powers_proj.push(power_i_proj); + } + + let powers = ProjectivePoint::batch_to_affine(&powers_proj); + MultiplicationPrecomputation { powers } + } + + pub fn mul_with_precomputation( + &self, + scalar: C::ScalarField, + precomputation: MultiplicationPrecomputation, + ) -> Self { + // Yao's method; see https://koclab.cs.ucsb.edu/teaching/ecc/eccPapers/Doche-ch09.pdf + let precomputed_powers = precomputation.powers; + + let digits = to_digits::(&scalar); + + let mut y = ProjectivePoint::ZERO; + let mut u = ProjectivePoint::ZERO; + for j in (1..BASE).rev() { + let mut u_summands = Vec::new(); + for (i, &digit) in digits.iter().enumerate() { + if digit == j as u64 { + u_summands.push(precomputed_powers[i]); + } + } + u = u + affine_summation_batch_inversion(u_summands); + y = y + u; + } + y + } +} + +impl Mul> for CurveScalar { + type Output = ProjectivePoint; + + fn mul(self, rhs: ProjectivePoint) -> Self::Output { + let precomputation = rhs.mul_precompute(); + rhs.mul_with_precomputation(self.0, precomputation) + } +} + +#[allow(clippy::assertions_on_constants)] +fn to_digits(x: &C::ScalarField) -> Vec { + debug_assert!( + 64 % WINDOW_BITS == 0, + "For simplicity, only power-of-two window sizes are handled for now" + ); + let digits_per_u64 = 64 / WINDOW_BITS; + let mut digits = Vec::with_capacity(digits_per_scalar::()); + for limb in x.to_biguint().to_u64_digits() { + for j in 0..digits_per_u64 { + digits.push((limb >> (j * WINDOW_BITS) as u64) % BASE as u64); + } + } + + digits +} diff --git a/src/curve/curve_summation.rs b/src/curve/curve_summation.rs new file mode 100644 index 00000000..501a4977 --- /dev/null +++ b/src/curve/curve_summation.rs @@ -0,0 +1,236 @@ +use std::iter::Sum; + +use crate::curve::curve_types::{AffinePoint, Curve, ProjectivePoint}; +use crate::field::field_types::Field; + +impl Sum> for ProjectivePoint { + fn sum>>(iter: I) -> ProjectivePoint { + let points: Vec<_> = iter.collect(); + affine_summation_best(points) + } +} + +impl Sum for ProjectivePoint { + fn sum>>(iter: I) -> ProjectivePoint { + iter.fold(ProjectivePoint::ZERO, |acc, x| acc + x) + } +} + +pub fn affine_summation_best(summation: Vec>) -> ProjectivePoint { + let result = affine_multisummation_best(vec![summation]); + debug_assert_eq!(result.len(), 1); + result[0] +} + +pub fn affine_multisummation_best( + summations: Vec>>, +) -> Vec> { + let pairwise_sums: usize = summations.iter().map(|summation| summation.len() / 2).sum(); + + // This threshold is chosen based on data from the summation benchmarks. + if pairwise_sums < 70 { + affine_multisummation_pairwise(summations) + } else { + affine_multisummation_batch_inversion(summations) + } +} + +/// Adds each pair of points using an affine + affine = projective formula, then adds up the +/// intermediate sums using a projective formula. +pub fn affine_multisummation_pairwise( + summations: Vec>>, +) -> Vec> { + summations + .into_iter() + .map(affine_summation_pairwise) + .collect() +} + +/// Adds each pair of points using an affine + affine = projective formula, then adds up the +/// intermediate sums using a projective formula. +pub fn affine_summation_pairwise(points: Vec>) -> ProjectivePoint { + let mut reduced_points: Vec> = Vec::new(); + for chunk in points.chunks(2) { + match chunk.len() { + 1 => reduced_points.push(chunk[0].to_projective()), + 2 => reduced_points.push(chunk[0] + chunk[1]), + _ => panic!(), + } + } + // TODO: Avoid copying (deref) + reduced_points + .iter() + .fold(ProjectivePoint::ZERO, |sum, x| sum + *x) +} + +/// Computes several summations of affine points by applying an affine group law, except that the +/// divisions are batched via Montgomery's trick. +pub fn affine_summation_batch_inversion( + summation: Vec>, +) -> ProjectivePoint { + let result = affine_multisummation_batch_inversion(vec![summation]); + debug_assert_eq!(result.len(), 1); + result[0] +} + +/// Computes several summations of affine points by applying an affine group law, except that the +/// divisions are batched via Montgomery's trick. +pub fn affine_multisummation_batch_inversion( + summations: Vec>>, +) -> Vec> { + let mut elements_to_invert = Vec::new(); + + // For each pair of points, (x1, y1) and (x2, y2), that we're going to add later, we want to + // invert either y (if the points are equal) or x1 - x2 (otherwise). We will use these later. + for summation in &summations { + let n = summation.len(); + // The special case for n=0 is to avoid underflow. + let range_end = if n == 0 { 0 } else { n - 1 }; + + for i in (0..range_end).step_by(2) { + let p1 = summation[i]; + let p2 = summation[i + 1]; + let AffinePoint { + x: x1, + y: y1, + zero: zero1, + } = p1; + let AffinePoint { + x: x2, + y: _y2, + zero: zero2, + } = p2; + + if zero1 || zero2 || p1 == -p2 { + // These are trivial cases where we won't need any inverse. + } else if p1 == p2 { + elements_to_invert.push(y1.double()); + } else { + elements_to_invert.push(x1 - x2); + } + } + } + + let inverses: Vec = + C::BaseField::batch_multiplicative_inverse(&elements_to_invert); + + let mut all_reduced_points = Vec::with_capacity(summations.len()); + let mut inverse_index = 0; + for summation in summations { + let n = summation.len(); + let mut reduced_points = Vec::with_capacity((n + 1) / 2); + + // The special case for n=0 is to avoid underflow. + let range_end = if n == 0 { 0 } else { n - 1 }; + + for i in (0..range_end).step_by(2) { + let p1 = summation[i]; + let p2 = summation[i + 1]; + let AffinePoint { + x: x1, + y: y1, + zero: zero1, + } = p1; + let AffinePoint { + x: x2, + y: y2, + zero: zero2, + } = p2; + + let sum = if zero1 { + p2 + } else if zero2 { + p1 + } else if p1 == -p2 { + AffinePoint::ZERO + } else { + // It's a non-trivial case where we need one of the inverses we computed earlier. + let inverse = inverses[inverse_index]; + inverse_index += 1; + + if p1 == p2 { + // This is the doubling case. + let mut numerator = x1.square().triple(); + if C::A.is_nonzero() { + numerator = numerator + C::A; + } + let quotient = numerator * inverse; + let x3 = quotient.square() - x1.double(); + let y3 = quotient * (x1 - x3) - y1; + AffinePoint::nonzero(x3, y3) + } else { + // This is the general case. We use the incomplete addition formulas 4.3 and 4.4. + let quotient = (y1 - y2) * inverse; + let x3 = quotient.square() - x1 - x2; + let y3 = quotient * (x1 - x3) - y1; + AffinePoint::nonzero(x3, y3) + } + }; + reduced_points.push(sum); + } + + // If n is odd, the last point was not part of a pair. + if n % 2 == 1 { + reduced_points.push(summation[n - 1]); + } + + all_reduced_points.push(reduced_points); + } + + // We should have consumed all of the inverses from the batch computation. + debug_assert_eq!(inverse_index, inverses.len()); + + // Recurse with our smaller set of points. + affine_multisummation_best(all_reduced_points) +} + +#[cfg(test)] +mod tests { + use crate::{ + affine_summation_batch_inversion, affine_summation_pairwise, Bls12377, Curve, + ProjectivePoint, + }; + + #[test] + fn test_pairwise_affine_summation() { + let g_affine = Bls12377::GENERATOR_AFFINE; + let g2_affine = (g_affine + g_affine).to_affine(); + let g3_affine = (g_affine + g_affine + g_affine).to_affine(); + let g2_proj = g2_affine.to_projective(); + let g3_proj = g3_affine.to_projective(); + assert_eq!( + affine_summation_pairwise::(vec![g_affine, g_affine]), + g2_proj + ); + assert_eq!( + affine_summation_pairwise::(vec![g_affine, g2_affine]), + g3_proj + ); + assert_eq!( + affine_summation_pairwise::(vec![g_affine, g_affine, g_affine]), + g3_proj + ); + assert_eq!( + affine_summation_pairwise::(vec![]), + ProjectivePoint::ZERO + ); + } + + #[test] + fn test_pairwise_affine_summation_batch_inversion() { + let g = Bls12377::GENERATOR_AFFINE; + let g_proj = g.to_projective(); + assert_eq!( + affine_summation_batch_inversion::(vec![g, g]), + g_proj + g_proj + ); + assert_eq!( + affine_summation_batch_inversion::(vec![g, g, g]), + g_proj + g_proj + g_proj + ); + assert_eq!( + affine_summation_batch_inversion::(vec![]), + ProjectivePoint::ZERO + ); + } +} diff --git a/src/curve/curve_types.rs b/src/curve/curve_types.rs index f7d55c0c..830dc7c1 100644 --- a/src/curve/curve_types.rs +++ b/src/curve/curve_types.rs @@ -1,9 +1,9 @@ +use std::fmt::Debug; use std::ops::Neg; use anyhow::Result; use crate::field::field_types::Field; -use std::fmt::Debug; // To avoid implementation conflicts from associated types, // see https://github.com/rust-lang/rust/issues/20400 @@ -54,9 +54,10 @@ pub trait Curve: 'static + Sync + Sized + Copy + Debug { Ok(res) }*/ - fn is_safe_curve() -> bool{ + fn is_safe_curve() -> bool { // Added additional check to prevent using vulnerabilties in case a discriminant is equal to 0. - (Self::A.cube().double().double() + Self::B.square().triple().triple().triple()).is_nonzero() + (Self::A.cube().double().double() + Self::B.square().triple().triple().triple()) + .is_nonzero() } } @@ -101,11 +102,7 @@ impl AffinePoint { } pub fn double(&self) -> Self { - let AffinePoint { - x: x1, - y: y1, - zero, - } = *self; + let AffinePoint { x: x1, y: y1, zero } = *self; if zero { return AffinePoint::ZERO; @@ -124,7 +121,6 @@ impl AffinePoint { zero: false, } } - } impl PartialEq for AffinePoint { diff --git a/src/curve/mod.rs b/src/curve/mod.rs index 1e536564..8b9df88e 100644 --- a/src/curve/mod.rs +++ b/src/curve/mod.rs @@ -1,2 +1,4 @@ pub mod curve_adds; +pub mod curve_multiplication; +pub mod curve_summation; pub mod curve_types; diff --git a/src/field/extension_field/quadratic.rs b/src/field/extension_field/quadratic.rs index e2794330..b724095a 100644 --- a/src/field/extension_field/quadratic.rs +++ b/src/field/extension_field/quadratic.rs @@ -67,6 +67,8 @@ impl> Field for QuadraticExtension { const MULTIPLICATIVE_GROUP_GENERATOR: Self = Self(F::EXT_MULTIPLICATIVE_GROUP_GENERATOR); const POWER_OF_TWO_GENERATOR: Self = Self(F::EXT_POWER_OF_TWO_GENERATOR); + const BITS: usize = F::BITS * 2; + fn order() -> BigUint { F::order() * F::order() } diff --git a/src/field/extension_field/quartic.rs b/src/field/extension_field/quartic.rs index 01918ff3..0d221401 100644 --- a/src/field/extension_field/quartic.rs +++ b/src/field/extension_field/quartic.rs @@ -69,6 +69,8 @@ impl> Field for QuarticExtension { const MULTIPLICATIVE_GROUP_GENERATOR: Self = Self(F::EXT_MULTIPLICATIVE_GROUP_GENERATOR); const POWER_OF_TWO_GENERATOR: Self = Self(F::EXT_POWER_OF_TWO_GENERATOR); + const BITS: usize = F::BITS * 4; + fn order() -> BigUint { F::order().pow(4u32) } diff --git a/src/field/field_types.rs b/src/field/field_types.rs index 250338fb..80eeecff 100644 --- a/src/field/field_types.rs +++ b/src/field/field_types.rs @@ -59,6 +59,8 @@ pub trait Field: /// Generator of a multiplicative subgroup of order `2^TWO_ADICITY`. const POWER_OF_TWO_GENERATOR: Self; + const BITS: usize; + fn order() -> BigUint; #[inline] diff --git a/src/field/goldilocks_field.rs b/src/field/goldilocks_field.rs index cb85d56d..058b6db8 100644 --- a/src/field/goldilocks_field.rs +++ b/src/field/goldilocks_field.rs @@ -82,6 +82,8 @@ impl Field for GoldilocksField { // ``` const POWER_OF_TWO_GENERATOR: Self = Self(1753635133440165772); + const BITS: usize = 64; + fn order() -> BigUint { Self::ORDER.into() } diff --git a/src/field/secp256k1.rs b/src/field/secp256k1.rs index 75221a1f..acb1df4e 100644 --- a/src/field/secp256k1.rs +++ b/src/field/secp256k1.rs @@ -91,6 +91,8 @@ impl Field for Secp256K1Base { // Sage: `g_2 = g^((p - 1) / 2)` const POWER_OF_TWO_GENERATOR: Self = Self::NEG_ONE; + const BITS: usize = 256; + fn order() -> BigUint { BigUint::from_slice(&[ 0xFFFFFC2F, 0xFFFFFFFE, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF,