From 284f9a412ca385959a8e156541276fced5e80e1c Mon Sep 17 00:00:00 2001 From: Nicholas Ward Date: Thu, 18 Nov 2021 10:30:57 -0800 Subject: [PATCH] curve multiply; test for curve add; addressed comments --- src/curve/curve_multiplication.rs | 10 +- src/curve/curve_types.rs | 31 +----- src/gadgets/biguint.rs | 11 ++ src/gadgets/curve.rs | 179 +++++++++++++++++++++++++++--- src/gadgets/nonnative.rs | 32 ++++++ 5 files changed, 220 insertions(+), 43 deletions(-) diff --git a/src/curve/curve_multiplication.rs b/src/curve/curve_multiplication.rs index e5ac0eb3..b09b8a0f 100644 --- a/src/curve/curve_multiplication.rs +++ b/src/curve/curve_multiplication.rs @@ -1,6 +1,6 @@ use std::ops::Mul; -use crate::curve::curve_summation::affine_summation_batch_inversion; +use crate::curve::curve_summation::affine_multisummation_batch_inversion; use crate::curve::curve_types::{AffinePoint, Curve, CurveScalar, ProjectivePoint}; use crate::field::field_types::Field; @@ -48,6 +48,7 @@ impl ProjectivePoint { let mut y = ProjectivePoint::ZERO; let mut u = ProjectivePoint::ZERO; + let mut all_summands = Vec::new(); for j in (1..BASE).rev() { let mut u_summands = Vec::new(); for (i, &digit) in digits.iter().enumerate() { @@ -55,7 +56,12 @@ impl ProjectivePoint { u_summands.push(precomputed_powers[i]); } } - u = u + affine_summation_batch_inversion(u_summands); + all_summands.push(u_summands); + } + + let all_sums = affine_multisummation_batch_inversion(all_summands); + for i in 0..all_sums.len() { + u = u + all_sums[i]; y = y + u; } y diff --git a/src/curve/curve_types.rs b/src/curve/curve_types.rs index f2bb24b5..c9a04ab2 100644 --- a/src/curve/curve_types.rs +++ b/src/curve/curve_types.rs @@ -1,8 +1,6 @@ use std::fmt::Debug; use std::ops::Neg; -use anyhow::Result; - use crate::field::field_types::Field; // To avoid implementation conflicts from associated types, @@ -29,30 +27,6 @@ pub trait Curve: 'static + Sync + Sized + Copy + Debug { CurveScalar(x) } - /*fn try_convert_b2s(x: Self::BaseField) -> Result { - x.try_convert::() - } - - fn try_convert_s2b(x: Self::ScalarField) -> Result { - x.try_convert::() - } - - fn try_convert_s2b_slice(s: &[Self::ScalarField]) -> Result> { - let mut res = Vec::with_capacity(s.len()); - for &x in s { - res.push(Self::try_convert_s2b(x)?); - } - Ok(res) - } - - fn try_convert_b2s_slice(s: &[Self::BaseField]) -> Result> { - let mut res = Vec::with_capacity(s.len()); - for &x in s { - res.push(Self::try_convert_b2s(x)?); - } - Ok(res) - }*/ - fn is_safe_curve() -> bool { // Added additional check to prevent using vulnerabilties in case a discriminant is equal to 0. (Self::A.cube().double().double() + Self::B.square().triple().triple().triple()) @@ -155,7 +129,7 @@ pub struct ProjectivePoint { impl ProjectivePoint { pub const ZERO: Self = Self { x: C::BaseField::ZERO, - y: C::BaseField::ZERO, + y: C::BaseField::ONE, z: C::BaseField::ZERO, }; @@ -166,7 +140,8 @@ impl ProjectivePoint { } pub fn is_valid(&self) -> bool { - self.to_affine().is_valid() + let Self { x, y, z } = *self; + z.is_zero() || y.square() * z == x.cube() + C::A * x * z.square() + C::B * z.cube() } pub fn to_affine(&self) -> AffinePoint { diff --git a/src/gadgets/biguint.rs b/src/gadgets/biguint.rs index 3aa96235..b67c85a5 100644 --- a/src/gadgets/biguint.rs +++ b/src/gadgets/biguint.rs @@ -155,6 +155,17 @@ impl, const D: usize> CircuitBuilder { } } + // Returns x * y + z. This is no more efficient than mul-then-add; it's purely for convenience (only need to call one CircuitBuilder function). + pub fn mul_add_biguint( + &mut self, + x: &BigUintTarget, + y: &BigUintTarget, + z: &BigUintTarget, + ) -> BigUintTarget { + let prod = self.mul_biguint(x, y); + self.add_biguint(&prod, z) + } + pub fn div_rem_biguint( &mut self, a: &BigUintTarget, diff --git a/src/gadgets/curve.rs b/src/gadgets/curve.rs index 5a458a56..eda0e5e0 100644 --- a/src/gadgets/curve.rs +++ b/src/gadgets/curve.rs @@ -1,6 +1,6 @@ use crate::curve::curve_types::{AffinePoint, Curve}; use crate::field::extension_field::Extendable; -use crate::field::field_types::RichField; +use crate::field::field_types::{Field, RichField}; use crate::gadgets::nonnative::NonNativeTarget; use crate::plonk::circuit_builder::CircuitBuilder; @@ -18,6 +18,17 @@ impl AffinePointTarget { } } +const WINDOW_BITS: usize = 4; +const BASE: usize = 1 << WINDOW_BITS; + +fn digits_per_scalar() -> usize { + (C::ScalarField::BITS + WINDOW_BITS - 1) / WINDOW_BITS +} + +pub struct MulPrecomputationTarget { + powers: Vec>, +} + impl, const D: usize> CircuitBuilder { pub fn constant_affine_point( &mut self, @@ -39,6 +50,13 @@ impl, const D: usize> CircuitBuilder { self.connect_nonnative(&lhs.y, &rhs.y); } + pub fn add_virtual_affine_point_target(&mut self) -> AffinePointTarget { + let x = self.add_virtual_nonnative_target(); + let y = self.add_virtual_nonnative_target(); + + AffinePointTarget { x, y } + } + pub fn curve_assert_valid(&mut self, p: &AffinePointTarget) { let a = self.constant_nonnative(C::A); let b = self.constant_nonnative(C::B); @@ -61,11 +79,7 @@ impl, const D: usize> CircuitBuilder { } } - pub fn curve_double( - &mut self, - p: &AffinePointTarget, - p_orig: AffinePoint, - ) -> AffinePointTarget { + pub fn curve_double(&mut self, p: &AffinePointTarget) -> AffinePointTarget { let AffinePointTarget { x, y } = p; let double_y = self.add_nonnative(y, y); let inv_double_y = self.inv_nonnative(&double_y); @@ -89,6 +103,7 @@ impl, const D: usize> CircuitBuilder { AffinePointTarget { x: x3, y: y3 } } + // Add two points, which are assumed to be non-equal. pub fn curve_add( &mut self, p1: &AffinePointTarget, @@ -122,6 +137,110 @@ impl, const D: usize> CircuitBuilder { y: y3_norm, } } + + pub fn mul_precompute( + &mut self, + p: &AffinePointTarget, + ) -> MulPrecomputationTarget { + let num_digits = digits_per_scalar::(); + + let mut powers = Vec::with_capacity(num_digits); + powers.push(p.clone()); + for i in 1..num_digits { + let mut power_i = powers[i - 1].clone(); + for _j in 0..WINDOW_BITS { + power_i = self.curve_double(&power_i); + } + powers.push(power_i); + } + + MulPrecomputationTarget { powers } + } + + /*fn to_digits(&mut self, x: &NonNativeTarget) -> Vec> { + debug_assert!( + 64 % WINDOW_BITS == 0, + "For simplicity, only power-of-two window sizes are handled for now" + ); + + let base = self.constant_nonnative(C::ScalarField::from_canonical_u64(BASE as u64)); + + let num_digits = digits_per_scalar::(); + let mut digits = Vec::with_capacity(num_digits); + + let (rest, limb) = self.div_rem_nonnative(&x, &base); + for _ in 0..num_digits { + digits.push(limb); + + let (rest, limb) = self.div_rem_nonnative(&rest, &base); + } + + digits + } + + pub fn mul_with_precomputation( + &mut self, + p: &AffinePointTarget, + n: &NonNativeTarget, + precomputation: MulPrecomputationTarget, + ) -> AffinePointTarget { + // Yao's method; see https://koclab.cs.ucsb.edu/teaching/ecc/eccPapers/Doche-ch09.pdf + let precomputed_powers = precomputation.powers; + + let digits = self.to_digits(n); + + + }*/ + + pub fn curve_scalar_mul( + &mut self, + p: &AffinePointTarget, + n: &NonNativeTarget, + ) -> AffinePointTarget { + let one = self.constant_nonnative(C::BaseField::ONE); + let two = self.constant_nonnative(C::ScalarField::TWO); + let num_bits = C::ScalarField::BITS; + + // Result starts at p, which is later subtracted, because we don't support arithmetic with the zero point. + let mut result = self.add_virtual_affine_point_target(); + self.connect_affine_point(p, &result); + let mut two_i_times_p = self.add_virtual_affine_point_target(); + self.connect_affine_point(p, &two_i_times_p); + + let mut cur_n = self.add_virtual_nonnative_target::(); + for _i in 0..num_bits { + let (bit_scalar, new_n) = self.div_rem_nonnative(&cur_n, &two); + let bit_biguint = self.nonnative_to_biguint(&bit_scalar); + let bit = self.biguint_to_nonnative::(&bit_biguint); + let not_bit = self.sub_nonnative(&one, &bit); + + let result_plus_2_i_p = self.curve_add(&result, &two_i_times_p); + + let result_x = result.x; + let result_y = result.y; + let result_plus_2_i_p_x = result_plus_2_i_p.x; + let result_plus_2_i_p_y = result_plus_2_i_p.y; + + let new_x_if_bit = self.mul_nonnative(&bit, &result_plus_2_i_p_x); + let new_x_if_not_bit = self.mul_nonnative(¬_bit, &result_x); + let new_y_if_bit = self.mul_nonnative(&bit, &result_plus_2_i_p_y); + let new_y_if_not_bit = self.mul_nonnative(¬_bit, &result_y); + + let new_x = self.add_nonnative(&new_x_if_bit, &new_x_if_not_bit); + let new_y = self.add_nonnative(&new_y_if_bit, &new_y_if_not_bit); + + result = AffinePointTarget { x: new_x, y: new_y }; + + two_i_times_p = self.curve_double(&two_i_times_p); + cur_n = new_n; + } + + // Subtract off result's intial value of p. + let neg_p = self.curve_neg(&p); + result = self.curve_add(&result, &neg_p); + + result + } } mod tests { @@ -200,19 +319,53 @@ mod tests { let mut builder = CircuitBuilder::::new(config); let g = Secp256K1::GENERATOR_AFFINE; - let neg_g = g.neg(); let g_target = builder.constant_affine_point(g); let neg_g_target = builder.curve_neg(&g_target); let double_g = g.double(); - let double_g_other_target = builder.constant_affine_point(double_g); - builder.curve_assert_valid(&double_g_other_target); + let double_g_expected = builder.constant_affine_point(double_g); + builder.curve_assert_valid(&double_g_expected); - let double_g_target = builder.curve_double(&g_target, g); - let double_neg_g_target = builder.curve_double(&neg_g_target, neg_g); + let double_neg_g = (-g).double(); + let double_neg_g_expected = builder.constant_affine_point(double_neg_g); + builder.curve_assert_valid(&double_neg_g_expected); - builder.curve_assert_valid(&double_g_target); - builder.curve_assert_valid(&double_neg_g_target); + let double_g_actual = builder.curve_double(&g_target); + let double_neg_g_actual = builder.curve_double(&neg_g_target); + builder.curve_assert_valid(&double_g_actual); + builder.curve_assert_valid(&double_neg_g_actual); + + builder.connect_affine_point(&double_g_expected, &double_g_actual); + builder.connect_affine_point(&double_neg_g_expected, &double_neg_g_actual); + + let data = builder.build(); + let proof = data.prove(pw).unwrap(); + + verify(proof, &data.verifier_only, &data.common) + } + + #[test] + fn test_curve_add() -> Result<()> { + type F = GoldilocksField; + const D: usize = 4; + + let config = CircuitConfig::standard_recursion_config(); + + let pw = PartialWitness::new(); + let mut builder = CircuitBuilder::::new(config); + + let g = Secp256K1::GENERATOR_AFFINE; + let double_g = g.double(); + let g_plus_2g = (g + double_g).to_affine(); + let g_plus_2g_expected = builder.constant_affine_point(g_plus_2g); + builder.curve_assert_valid(&g_plus_2g_expected); + + let g_target = builder.constant_affine_point(g); + let double_g_target = builder.curve_double(&g_target); + let g_plus_2g_actual = builder.curve_add(&g_target, &double_g_target); + builder.curve_assert_valid(&g_plus_2g_actual); + + builder.connect_affine_point(&g_plus_2g_expected, &g_plus_2g_actual); let data = builder.build(); let proof = data.prove(pw).unwrap(); diff --git a/src/gadgets/nonnative.rs b/src/gadgets/nonnative.rs index 90735a61..19d86658 100644 --- a/src/gadgets/nonnative.rs +++ b/src/gadgets/nonnative.rs @@ -17,6 +17,14 @@ pub struct NonNativeTarget { } impl, const D: usize> CircuitBuilder { + fn num_nonnative_limbs() -> usize { + let ff_size = FF::order(); + let f_size = F::order(); + let num_limbs = ((ff_size + f_size.clone() - BigUint::one()) / f_size).to_u32_digits()[0]; + + num_limbs as usize + } + pub fn biguint_to_nonnative(&mut self, x: &BigUintTarget) -> NonNativeTarget { NonNativeTarget { value: x.clone(), @@ -42,6 +50,16 @@ impl, const D: usize> CircuitBuilder { self.connect_biguint(&lhs.value, &rhs.value); } + pub fn add_virtual_nonnative_target(&mut self) -> NonNativeTarget { + let num_limbs = Self::num_nonnative_limbs::(); + let value = self.add_virtual_biguint_target(num_limbs); + + NonNativeTarget { + value, + _phantom: PhantomData, + } + } + // Add two `NonNativeTarget`s. pub fn add_nonnative( &mut self, @@ -106,6 +124,20 @@ impl, const D: usize> CircuitBuilder { inv } + pub fn div_rem_nonnative( + &mut self, + x: &NonNativeTarget, + y: &NonNativeTarget, + ) -> (NonNativeTarget, NonNativeTarget) { + let x_biguint = self.nonnative_to_biguint(x); + let y_biguint = self.nonnative_to_biguint(y); + + let (div_biguint, rem_biguint) = self.div_rem_biguint(&x_biguint, &y_biguint); + let div = self.biguint_to_nonnative(&div_biguint); + let rem = self.biguint_to_nonnative(&rem_biguint); + (div, rem) + } + /// Returns `x % |FF|` as a `NonNativeTarget`. fn reduce(&mut self, x: &BigUintTarget) -> NonNativeTarget { let modulus = FF::order();