diff --git a/plonky2/src/curve/glv.rs b/plonky2/src/curve/glv.rs index ded0d0b3..a2786e31 100644 --- a/plonky2/src/curve/glv.rs +++ b/plonky2/src/curve/glv.rs @@ -1,4 +1,5 @@ use num::rational::Ratio; +use num::BigUint; use plonky2_field::field_types::Field; use plonky2_field::secp256k1_base::Secp256K1Base; use plonky2_field::secp256k1_scalar::Secp256K1Scalar; @@ -30,26 +31,46 @@ const A2: Secp256K1Scalar = Secp256K1Scalar([6323353552219852760, 14980988506747 const B2: Secp256K1Scalar = Secp256K1Scalar([16747920425669159701, 3496713202691238861, 0, 0]); -pub fn decompose_secp256k1_scalar(k: Secp256K1Scalar) -> (Secp256K1Scalar, Secp256K1Scalar) { +pub fn decompose_secp256k1_scalar( + k: Secp256K1Scalar, +) -> (Secp256K1Scalar, Secp256K1Scalar, bool, bool) { let p = Secp256K1Scalar::order(); let c1_biguint = Ratio::new(B2.to_biguint() * k.to_biguint(), p.clone()) .round() .to_integer(); let c1 = Secp256K1Scalar::from_biguint(c1_biguint); - let c2_biguint = Ratio::new(MINUS_B1.to_biguint() * k.to_biguint(), p) + let c2_biguint = Ratio::new(MINUS_B1.to_biguint() * k.to_biguint(), p.clone()) .round() .to_integer(); let c2 = Secp256K1Scalar::from_biguint(c2_biguint); - let k1 = k - c1 * A1 - c2 * A2; - let k2 = c1 * MINUS_B1 - c2 * B2; - debug_assert!(k1 + S * k2 == k); - (k1, k2) + let k1_raw = k - c1 * A1 - c2 * A2; + let k2_raw = c1 * MINUS_B1 - c2 * B2; + debug_assert!(k1_raw + S * k2_raw == k); + + let two = BigUint::from_slice(&[2]); + let k1_neg = k1_raw.to_biguint() > p.clone() / two.clone(); + let k1 = if k1_neg { + Secp256K1Scalar::from_biguint(p.clone() - k1_raw.to_biguint()) + } else { + k1_raw + }; + let k2_neg = k2_raw.to_biguint() > p.clone() / two.clone(); + let k2 = if k2_neg { + Secp256K1Scalar::from_biguint(p.clone() - k2_raw.to_biguint()) + } else { + k2_raw + }; + + (k1, k2, k1_neg, k2_neg) } pub fn glv_mul(p: ProjectivePoint, k: Secp256K1Scalar) -> ProjectivePoint { - let (k1, k2) = decompose_secp256k1_scalar(k); - assert!(k1 + S * k2 == k); + let (k1, k2, k1_neg, k2_neg) = decompose_secp256k1_scalar(k); + let one = Secp256K1Scalar::ONE; + /*let m1 = if k1_neg { -one } else { one }; + let m2 = if k2_neg { -one } else { one }; + assert!(k1 * m1 + S * k2 * m2 == k);*/ let p_affine = p.to_affine(); let sp = AffinePoint:: { @@ -58,7 +79,10 @@ pub fn glv_mul(p: ProjectivePoint, k: Secp256K1Scalar) -> ProjectiveP zero: p_affine.zero, }; - msm_parallel(&[k1, k2], &[p, sp.to_projective()], 5) + let first = if k1_neg { p.neg() } else { p }; + let second = if k2_neg { sp.to_projective().neg() } else { sp.to_projective() }; + + msm_parallel(&[k1, k2], &[first, second], 5) } #[cfg(test)] @@ -74,9 +98,12 @@ mod tests { #[test] fn test_glv_decompose() -> Result<()> { let k = Secp256K1Scalar::rand(); - let (k1, k2) = decompose_secp256k1_scalar(k); + let (k1, k2, k1_neg, k2_neg) = decompose_secp256k1_scalar(k); + let one = Secp256K1Scalar::ONE; + let m1 = if k1_neg { -one } else { one }; + let m2 = if k2_neg { -one } else { one }; - assert!(k1 + S * k2 == k); + assert!(k1 * m1 + S * k2 * m2 == k); Ok(()) } diff --git a/plonky2/src/gadgets/curve.rs b/plonky2/src/gadgets/curve.rs index 2b27c120..aa8c4112 100644 --- a/plonky2/src/gadgets/curve.rs +++ b/plonky2/src/gadgets/curve.rs @@ -71,6 +71,24 @@ impl, const D: usize> CircuitBuilder { } } + pub fn curve_conditional_neg( + &mut self, + p: &AffinePointTarget, + b: BoolTarget, + ) -> AffinePointTarget { + let not_b = self.not(b); + let neg = self.curve_neg(p); + let x_if_true = self.mul_nonnative_by_bool(&neg.x, b); + let y_if_true = self.mul_nonnative_by_bool(&neg.y, b); + let x_if_false = self.mul_nonnative_by_bool(&p.x, not_b); + let y_if_false = self.mul_nonnative_by_bool(&p.y, not_b); + + let x = self.add_nonnative(&x_if_true, &x_if_false); + let y = self.add_nonnative(&y_if_true, &y_if_false); + + AffinePointTarget { x, y } + } + pub fn curve_double(&mut self, p: &AffinePointTarget) -> AffinePointTarget { let AffinePointTarget { x, y } = p; let double_y = self.add_nonnative(y, y); diff --git a/plonky2/src/gadgets/curve_windowed_mul.rs b/plonky2/src/gadgets/curve_windowed_mul.rs index 2f0516dd..f4cebe0e 100644 --- a/plonky2/src/gadgets/curve_windowed_mul.rs +++ b/plonky2/src/gadgets/curve_windowed_mul.rs @@ -18,18 +18,18 @@ use crate::plonk::config::{GenericHashOut, Hasher}; const WINDOW_SIZE: usize = 4; impl, const D: usize> CircuitBuilder { - // TODO: fix if p is the generator pub fn precompute_window( &mut self, p: &AffinePointTarget, ) -> Vec> { + let g = (CurveScalar(C::ScalarField::rand()) * C::GENERATOR_PROJECTIVE).to_affine(); let neg = { - let mut g = C::GENERATOR_AFFINE; - g.y = -g.y; - self.constant_affine_point(g) + let mut neg = g; + neg.y = -neg.y; + self.constant_affine_point(neg) }; - let mut multiples = vec![self.constant_affine_point(C::GENERATOR_AFFINE)]; + let mut multiples = vec![self.constant_affine_point(g)]; for i in 1..1 << WINDOW_SIZE { multiples.push(self.curve_add(p, &multiples[i - 1])); } diff --git a/plonky2/src/gadgets/glv.rs b/plonky2/src/gadgets/glv.rs index 5ac3fb93..5614de55 100644 --- a/plonky2/src/gadgets/glv.rs +++ b/plonky2/src/gadgets/glv.rs @@ -10,7 +10,7 @@ use crate::gadgets::curve::AffinePointTarget; use crate::gadgets::nonnative::NonNativeTarget; use crate::hash::hash_types::RichField; use crate::iop::generator::{GeneratedValues, SimpleGenerator}; -use crate::iop::target::Target; +use crate::iop::target::{BoolTarget, Target}; use crate::iop::witness::{PartitionWitness, Witness}; use crate::plonk::circuit_builder::CircuitBuilder; @@ -25,18 +25,24 @@ impl, const D: usize> CircuitBuilder { ) -> ( NonNativeTarget, NonNativeTarget, + BoolTarget, + BoolTarget, ) { - let k1 = self.add_virtual_nonnative_target::(); - let k2 = self.add_virtual_nonnative_target::(); + let k1 = self.add_virtual_nonnative_target_sized::(4); + let k2 = self.add_virtual_nonnative_target_sized::(4); + let k1_neg = self.add_virtual_bool_target(); + let k2_neg = self.add_virtual_bool_target(); self.add_simple_generator(GLVDecompositionGenerator:: { k: k.clone(), k1: k1.clone(), k2: k2.clone(), + k1_neg: k1_neg.clone(), + k2_neg: k2_neg.clone(), _phantom: PhantomData, }); - (k1, k2) + (k1, k2, k1_neg, k2_neg) } pub fn glv_mul( @@ -44,7 +50,7 @@ impl, const D: usize> CircuitBuilder { p: &AffinePointTarget, k: &NonNativeTarget, ) -> AffinePointTarget { - let (k1, k2) = self.decompose_secp256k1_scalar(k); + let (k1, k2, k1_neg, k2_neg) = self.decompose_secp256k1_scalar(k); let beta = self.secp256k1_glv_beta(); let beta_px = self.mul_nonnative(&beta, &p.x); @@ -54,9 +60,11 @@ impl, const D: usize> CircuitBuilder { }; let part1 = self.curve_scalar_mul_windowed(p, &k1); + let part1_neg = self.curve_conditional_neg(&part1, k1_neg); let part2 = self.curve_scalar_mul_windowed(&sp, &k2); + let part2_neg = self.curve_conditional_neg(&part2, k2_neg); - self.curve_add(&part1, &part2) + self.curve_add(&part1_neg, &part2_neg) } } @@ -65,6 +73,8 @@ struct GLVDecompositionGenerator, const D: usize> { k: NonNativeTarget, k1: NonNativeTarget, k2: NonNativeTarget, + k1_neg: BoolTarget, + k2_neg: BoolTarget, _phantom: PhantomData, } @@ -77,10 +87,12 @@ impl, const D: usize> SimpleGenerator fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { let k = witness.get_nonnative_target(self.k.clone()); - let (k1, k2) = decompose_secp256k1_scalar(k); + let (k1, k2, k1_neg, k2_neg) = decompose_secp256k1_scalar(k); out_buffer.set_nonnative_target(self.k1.clone(), k1); out_buffer.set_nonnative_target(self.k2.clone(), k2); + out_buffer.set_bool_target(self.k1_neg.clone(), k1_neg); + out_buffer.set_bool_target(self.k2_neg.clone(), k2_neg); } } @@ -100,7 +112,7 @@ mod tests { use crate::plonk::verifier::verify; #[test] - fn test_glv() -> Result<()> { + fn test_glv_gadget() -> Result<()> { const D: usize = 2; type C = PoseidonGoldilocksConfig; type F = >::F; diff --git a/plonky2/src/gadgets/nonnative.rs b/plonky2/src/gadgets/nonnative.rs index 44969ac9..046931d2 100644 --- a/plonky2/src/gadgets/nonnative.rs +++ b/plonky2/src/gadgets/nonnative.rs @@ -63,6 +63,18 @@ impl, const D: usize> CircuitBuilder { } } + pub fn add_virtual_nonnative_target_sized( + &mut self, + num_limbs: usize, + ) -> NonNativeTarget { + let value = self.add_virtual_biguint_target(num_limbs); + + NonNativeTarget { + value, + _phantom: PhantomData, + } + } + pub fn add_nonnative( &mut self, a: &NonNativeTarget,