From 772ff8d69ab7fb4e1d64dca10cec1a6739f0694e Mon Sep 17 00:00:00 2001 From: wborgeaud Date: Fri, 25 Feb 2022 16:30:01 +0100 Subject: [PATCH] Works --- plonky2/src/gadgets/curve_msm.rs | 155 +++++++++++++++++++++++++ plonky2/src/gadgets/ecdsa.rs | 5 +- plonky2/src/gadgets/mod.rs | 1 + plonky2/src/gadgets/split_nonnative.rs | 11 ++ 4 files changed, 170 insertions(+), 2 deletions(-) create mode 100644 plonky2/src/gadgets/curve_msm.rs diff --git a/plonky2/src/gadgets/curve_msm.rs b/plonky2/src/gadgets/curve_msm.rs new file mode 100644 index 00000000..d0add327 --- /dev/null +++ b/plonky2/src/gadgets/curve_msm.rs @@ -0,0 +1,155 @@ +use num::BigUint; +use plonky2_field::extension_field::Extendable; + +use crate::curve::curve_types::{Curve, CurveScalar}; +use crate::field::field_types::Field; +use crate::gadgets::curve::AffinePointTarget; +use crate::gadgets::nonnative::NonNativeTarget; +use crate::hash::hash_types::RichField; +use crate::hash::keccak::KeccakHash; +use crate::plonk::circuit_builder::CircuitBuilder; +use crate::plonk::config::{GenericHashOut, Hasher}; + +impl, const D: usize> CircuitBuilder { + /// Computes `n*p + m*q`. + pub fn curve_msm( + &mut self, + p: &AffinePointTarget, + q: &AffinePointTarget, + n: &NonNativeTarget, + m: &NonNativeTarget, + ) -> AffinePointTarget { + let bits_n = self.split_nonnative_to_bits(n); + let bits_m = self.split_nonnative_to_bits(m); + assert_eq!(bits_n.len(), bits_m.len()); + + let sum = self.curve_add(p, q); + let precomputation = vec![p.clone(), p.clone(), q.clone(), sum]; + + let two = self.two(); + let hash_0 = KeccakHash::<32>::hash_no_pad(&[F::ZERO]); + let hash_0_scalar = C::ScalarField::from_biguint(BigUint::from_bytes_le( + &GenericHashOut::::to_bytes(&hash_0), + )); + let starting_point = CurveScalar(hash_0_scalar) * C::GENERATOR_PROJECTIVE; + let starting_point_multiplied = + (0..C::ScalarField::BITS).fold(starting_point, |acc, _| acc.double()); + + let zero = self.zero(); + let mut result = self.constant_affine_point(starting_point.to_affine()); + for (b_n, b_m) in bits_n.into_iter().zip(bits_m).rev() { + result = self.curve_double(&result); + let index = self.mul_add(two, b_m.target, b_n.target); + let r = self.random_access_curve_points(index, precomputation.clone()); + let is_zero = self.is_equal(index, zero); + let should_add = self.not(is_zero); + result = self.curve_conditional_add(&result, &r, should_add); + } + let to_subtract = self.constant_affine_point(starting_point_multiplied.to_affine()); + let to_add = self.curve_neg(&to_subtract); + result = self.curve_add(&result, &to_add); + + result + } +} + +#[cfg(test)] +mod tests { + use std::ops::Neg; + + use anyhow::Result; + use plonky2_field::secp256k1_scalar::Secp256K1Scalar; + + use crate::curve::curve_types::{Curve, CurveScalar}; + use crate::curve::secp256k1::Secp256K1; + use crate::field::field_types::Field; + use crate::iop::witness::PartialWitness; + use crate::plonk::circuit_builder::CircuitBuilder; + use crate::plonk::circuit_data::CircuitConfig; + use crate::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; + use crate::plonk::verifier::verify; + + #[test] + fn test_yo() -> Result<()> { + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + + let config = CircuitConfig::standard_ecc_config(); + + let pw = PartialWitness::new(); + let mut builder = CircuitBuilder::::new(config); + + let p = + (CurveScalar(Secp256K1Scalar::rand()) * Secp256K1::GENERATOR_PROJECTIVE).to_affine(); + let q = + (CurveScalar(Secp256K1Scalar::rand()) * Secp256K1::GENERATOR_PROJECTIVE).to_affine(); + let n = Secp256K1Scalar::rand(); + let m = Secp256K1Scalar::rand(); + + let res = + (CurveScalar(n) * p.to_projective() + CurveScalar(m) * q.to_projective()).to_affine(); + let res_expected = builder.constant_affine_point(res); + builder.curve_assert_valid(&res_expected); + + let p_target = builder.constant_affine_point(p); + let q_target = builder.constant_affine_point(q); + let n_target = builder.constant_nonnative(n); + let m_target = builder.constant_nonnative(m); + + let res_target = builder.curve_msm(&p_target, &q_target, &n_target, &m_target); + builder.curve_assert_valid(&res_target); + + builder.connect_affine_point(&res_target, &res_expected); + + dbg!(builder.num_gates()); + let data = builder.build::(); + let proof = data.prove(pw).unwrap(); + + verify(proof, &data.verifier_only, &data.common) + } + + #[test] + fn test_ya() -> Result<()> { + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + + let config = CircuitConfig::standard_ecc_config(); + + let pw = PartialWitness::new(); + let mut builder = CircuitBuilder::::new(config); + + let p = + (CurveScalar(Secp256K1Scalar::rand()) * Secp256K1::GENERATOR_PROJECTIVE).to_affine(); + let q = + (CurveScalar(Secp256K1Scalar::rand()) * Secp256K1::GENERATOR_PROJECTIVE).to_affine(); + let n = Secp256K1Scalar::rand(); + let m = Secp256K1Scalar::rand(); + + let res = + (CurveScalar(n) * p.to_projective() + CurveScalar(m) * q.to_projective()).to_affine(); + let res_expected = builder.constant_affine_point(res); + builder.curve_assert_valid(&res_expected); + + let p_target = builder.constant_affine_point(p); + let q_target = builder.constant_affine_point(q); + let n_target = builder.constant_nonnative(n); + let m_target = builder.constant_nonnative(m); + + // let res0_target = builder.curve_scalar_mul_windowed(&p_target, &n_target); + // let res1_target = builder.curve_scalar_mul_windowed(&q_target, &m_target); + let res0_target = builder.curve_scalar_mul(&p_target, &n_target); + let res1_target = builder.curve_scalar_mul(&q_target, &m_target); + let res_target = builder.curve_add(&res0_target, &res1_target); + builder.curve_assert_valid(&res_target); + + builder.connect_affine_point(&res_target, &res_expected); + + dbg!(builder.num_gates()); + let data = builder.build::(); + let proof = data.prove(pw).unwrap(); + + verify(proof, &data.verifier_only, &data.common) + } +} diff --git a/plonky2/src/gadgets/ecdsa.rs b/plonky2/src/gadgets/ecdsa.rs index 0a95e189..5f4c4ff1 100644 --- a/plonky2/src/gadgets/ecdsa.rs +++ b/plonky2/src/gadgets/ecdsa.rs @@ -35,8 +35,8 @@ impl, const D: usize> CircuitBuilder { let u2 = self.mul_nonnative(&r, &c); let g = self.constant_affine_point(C::GENERATOR_AFFINE); - let point1 = self.curve_scalar_mul(&g, &u1); - let point2 = self.curve_scalar_mul(&pk.0, &u2); + let point1 = self.curve_scalar_mul_windowed(&g, &u1); + let point2 = self.curve_scalar_mul_windowed(&pk.0, &u2); let point = self.curve_add(&point1, &point2); let x = NonNativeTarget:: { @@ -97,6 +97,7 @@ mod tests { builder.verify_message(msg_target, sig_target, pk_target); + dbg!(builder.num_gates()); let data = builder.build::(); let proof = data.prove(pw).unwrap(); verify(proof, &data.verifier_only, &data.common) diff --git a/plonky2/src/gadgets/mod.rs b/plonky2/src/gadgets/mod.rs index d9c93db3..e35afeed 100644 --- a/plonky2/src/gadgets/mod.rs +++ b/plonky2/src/gadgets/mod.rs @@ -5,6 +5,7 @@ pub mod biguint; pub mod curve; pub mod curve_windowed_mul; // pub mod curve_msm; +pub mod curve_msm; pub mod ecdsa; pub mod glv; pub mod hash; diff --git a/plonky2/src/gadgets/split_nonnative.rs b/plonky2/src/gadgets/split_nonnative.rs index 70661506..18fc0264 100644 --- a/plonky2/src/gadgets/split_nonnative.rs +++ b/plonky2/src/gadgets/split_nonnative.rs @@ -35,6 +35,17 @@ impl, const D: usize> CircuitBuilder { .collect() } + pub fn split_nonnative_to_2_bit_limbs( + &mut self, + val: &NonNativeTarget, + ) -> Vec { + val.value + .limbs + .iter() + .flat_map(|&l| self.split_le_base::<4>(l.0, 16)) + .collect() + } + // Note: assumes its inputs are 4-bit limbs, and does not range-check. pub fn recombine_nonnative_4_bit_limbs( &mut self,