From efb074b247bf6b7f86986cbe50fdb729f79f201d Mon Sep 17 00:00:00 2001 From: wborgeaud Date: Fri, 25 Feb 2022 17:21:35 +0100 Subject: [PATCH] Works with 2 --- plonky2/src/gadgets/curve_msm.rs | 52 +++++++++++++++++++++++--------- 1 file changed, 38 insertions(+), 14 deletions(-) diff --git a/plonky2/src/gadgets/curve_msm.rs b/plonky2/src/gadgets/curve_msm.rs index d0add327..e8db6fdd 100644 --- a/plonky2/src/gadgets/curve_msm.rs +++ b/plonky2/src/gadgets/curve_msm.rs @@ -1,3 +1,5 @@ +use std::collections::HashMap; + use num::BigUint; use plonky2_field::extension_field::Extendable; @@ -19,27 +21,51 @@ impl, const D: usize> CircuitBuilder { n: &NonNativeTarget, m: &NonNativeTarget, ) -> AffinePointTarget { - let bits_n = self.split_nonnative_to_bits(n); - let bits_m = self.split_nonnative_to_bits(m); - assert_eq!(bits_n.len(), bits_m.len()); + let limbs_n = self.split_nonnative_to_2_bit_limbs(n); + let limbs_m = self.split_nonnative_to_2_bit_limbs(m); + assert_eq!(limbs_n.len(), limbs_m.len()); - let sum = self.curve_add(p, q); - let precomputation = vec![p.clone(), p.clone(), q.clone(), sum]; - - let two = self.two(); let hash_0 = KeccakHash::<32>::hash_no_pad(&[F::ZERO]); let hash_0_scalar = C::ScalarField::from_biguint(BigUint::from_bytes_le( &GenericHashOut::::to_bytes(&hash_0), )); let starting_point = CurveScalar(hash_0_scalar) * C::GENERATOR_PROJECTIVE; + let starting_point_t = self.constant_affine_point(starting_point.to_affine()); + let neg = { + let mut neg = starting_point.to_affine(); + neg.y = -neg.y; + self.constant_affine_point(neg) + }; + + let mut precomputation = vec![p.clone(); 16]; + let mut cur_p = starting_point_t.clone(); + let mut cur_q = starting_point_t.clone(); + for i in 0..4 { + precomputation[i] = cur_p.clone(); + precomputation[4 * i] = cur_q.clone(); + cur_p = self.curve_add(&cur_p, p); + cur_q = self.curve_add(&cur_q, q); + } + for i in 1..4 { + precomputation[i] = self.curve_add(&precomputation[i], &neg); + precomputation[4 * i] = self.curve_add(&precomputation[4 * i], &neg); + } + for i in 1..4 { + for j in 1..4 { + precomputation[i + 4 * j] = + self.curve_add(&precomputation[i], &precomputation[4 * j]); + } + } + + let four = self.constant(F::from_canonical_usize(4)); let starting_point_multiplied = (0..C::ScalarField::BITS).fold(starting_point, |acc, _| acc.double()); let zero = self.zero(); let mut result = self.constant_affine_point(starting_point.to_affine()); - for (b_n, b_m) in bits_n.into_iter().zip(bits_m).rev() { - result = self.curve_double(&result); - let index = self.mul_add(two, b_m.target, b_n.target); + for (limb_n, limb_m) in limbs_n.into_iter().zip(limbs_m).rev() { + result = self.curve_repeated_double(&result, 2); + let index = self.mul_add(four, limb_m, limb_n); let r = self.random_access_curve_points(index, precomputation.clone()); let is_zero = self.is_equal(index, zero); let should_add = self.not(is_zero); @@ -137,10 +163,8 @@ mod tests { let n_target = builder.constant_nonnative(n); let m_target = builder.constant_nonnative(m); - // let res0_target = builder.curve_scalar_mul_windowed(&p_target, &n_target); - // let res1_target = builder.curve_scalar_mul_windowed(&q_target, &m_target); - let res0_target = builder.curve_scalar_mul(&p_target, &n_target); - let res1_target = builder.curve_scalar_mul(&q_target, &m_target); + let res0_target = builder.curve_scalar_mul_windowed(&p_target, &n_target); + let res1_target = builder.curve_scalar_mul_windowed(&q_target, &m_target); let res_target = builder.curve_add(&res0_target, &res1_target); builder.curve_assert_valid(&res_target);