From 2ec3ea8634e7f25d9c5d3f70f5511c36048d663c Mon Sep 17 00:00:00 2001 From: Nicholas Ward Date: Thu, 18 Nov 2021 15:48:28 -0800 Subject: [PATCH] new curve_mul --- src/gadgets/curve.rs | 52 +++++++++++++++++++++++++++++++--------- src/gadgets/nonnative.rs | 35 ++++++++++++++++++++++++++- 2 files changed, 75 insertions(+), 12 deletions(-) diff --git a/src/gadgets/curve.rs b/src/gadgets/curve.rs index eda0e5e0..0982d5f9 100644 --- a/src/gadgets/curve.rs +++ b/src/gadgets/curve.rs @@ -201,18 +201,18 @@ impl, const D: usize> CircuitBuilder { let two = self.constant_nonnative(C::ScalarField::TWO); let num_bits = C::ScalarField::BITS; + let bits = self.split_nonnative_to_bits(&n); + let bits_as_base: Vec> = + bits.iter().map(|b| self.bool_to_nonnative(b)).collect(); + // Result starts at p, which is later subtracted, because we don't support arithmetic with the zero point. let mut result = self.add_virtual_affine_point_target(); self.connect_affine_point(p, &result); let mut two_i_times_p = self.add_virtual_affine_point_target(); self.connect_affine_point(p, &two_i_times_p); - let mut cur_n = self.add_virtual_nonnative_target::(); - for _i in 0..num_bits { - let (bit_scalar, new_n) = self.div_rem_nonnative(&cur_n, &two); - let bit_biguint = self.nonnative_to_biguint(&bit_scalar); - let bit = self.biguint_to_nonnative::(&bit_biguint); - let not_bit = self.sub_nonnative(&one, &bit); + for bit in bits_as_base.iter() { + let not_bit = self.sub_nonnative(&one, bit); let result_plus_2_i_p = self.curve_add(&result, &two_i_times_p); @@ -221,9 +221,9 @@ impl, const D: usize> CircuitBuilder { let result_plus_2_i_p_x = result_plus_2_i_p.x; let result_plus_2_i_p_y = result_plus_2_i_p.y; - let new_x_if_bit = self.mul_nonnative(&bit, &result_plus_2_i_p_x); + let new_x_if_bit = self.mul_nonnative(bit, &result_plus_2_i_p_x); let new_x_if_not_bit = self.mul_nonnative(¬_bit, &result_x); - let new_y_if_bit = self.mul_nonnative(&bit, &result_plus_2_i_p_y); + let new_y_if_bit = self.mul_nonnative(bit, &result_plus_2_i_p_y); let new_y_if_not_bit = self.mul_nonnative(¬_bit, &result_y); let new_x = self.add_nonnative(&new_x_if_bit, &new_x_if_not_bit); @@ -232,7 +232,6 @@ impl, const D: usize> CircuitBuilder { result = AffinePointTarget { x: new_x, y: new_y }; two_i_times_p = self.curve_double(&two_i_times_p); - cur_n = new_n; } // Subtract off result's intial value of p. @@ -244,15 +243,16 @@ impl, const D: usize> CircuitBuilder { } mod tests { - use std::ops::Neg; + use std::ops::{Mul, Neg}; use anyhow::Result; - use crate::curve::curve_types::{AffinePoint, Curve}; + use crate::curve::curve_types::{AffinePoint, Curve, CurveScalar}; use crate::curve::secp256k1::Secp256K1; use crate::field::field_types::Field; use crate::field::goldilocks_field::GoldilocksField; use crate::field::secp256k1_base::Secp256K1Base; + use crate::field::secp256k1_scalar::Secp256K1Scalar; use crate::iop::witness::PartialWitness; use crate::plonk::circuit_builder::CircuitBuilder; use crate::plonk::circuit_data::CircuitConfig; @@ -372,4 +372,34 @@ mod tests { verify(proof, &data.verifier_only, &data.common) } + + #[test] + fn test_curve_mul() -> Result<()> { + type F = GoldilocksField; + const D: usize = 4; + + let config = CircuitConfig::standard_recursion_config(); + + let pw = PartialWitness::new(); + let mut builder = CircuitBuilder::::new(config); + + let g = Secp256K1::GENERATOR_AFFINE; + let five = Secp256K1Scalar::from_canonical_usize(5); + let five_scalar = CurveScalar::(five); + let five_g = (five_scalar * g.to_projective()).to_affine(); + let five_g_expected = builder.constant_affine_point(five_g); + builder.curve_assert_valid(&five_g_expected); + + let g_target = builder.constant_affine_point(g); + let five_target = builder.constant_nonnative(five); + let five_g_actual = builder.curve_scalar_mul(&g_target, &five_target); + builder.curve_assert_valid(&five_g_actual); + + builder.connect_affine_point(&five_g_expected, &five_g_actual); + + let data = builder.build(); + let proof = data.prove(pw).unwrap(); + + verify(proof, &data.verifier_only, &data.common) + } } diff --git a/src/gadgets/nonnative.rs b/src/gadgets/nonnative.rs index 19d86658..88250093 100644 --- a/src/gadgets/nonnative.rs +++ b/src/gadgets/nonnative.rs @@ -4,9 +4,10 @@ use num::{BigUint, One, Zero}; use crate::field::field_types::RichField; use crate::field::{extension_field::Extendable, field_types::Field}; +use crate::gadgets::arithmetic_u32::U32Target; use crate::gadgets::biguint::BigUintTarget; use crate::iop::generator::{GeneratedValues, SimpleGenerator}; -use crate::iop::target::Target; +use crate::iop::target::{BoolTarget, Target}; use crate::iop::witness::{PartitionWitness, Witness}; use crate::plonk::circuit_builder::CircuitBuilder; @@ -158,6 +159,38 @@ impl, const D: usize> CircuitBuilder { let x_biguint = self.nonnative_to_biguint(x); self.reduce(&x_biguint) } + + pub fn bool_to_nonnative(&mut self, b: &BoolTarget) -> NonNativeTarget { + let limbs = vec![U32Target(b.target)]; + let value = BigUintTarget { limbs }; + + NonNativeTarget { + value, + _phantom: PhantomData, + } + } + + // Split a nonnative field element to bits. + pub fn split_nonnative_to_bits( + &mut self, + x: &NonNativeTarget, + ) -> Vec { + let num_limbs = x.value.num_limbs(); + let mut result = Vec::with_capacity(num_limbs * 32); + + for i in 0..num_limbs { + let limb = x.value.get_limb(i); + let bit_targets = self.split_le_base::<2>(limb.0, 32); + let mut bits: Vec<_> = bit_targets + .iter() + .map(|&t| BoolTarget::new_unsafe(t)) + .collect(); + + result.append(&mut bits); + } + + result + } } #[derive(Debug)]