diff --git a/plonky2/src/curve/glv.rs b/plonky2/src/curve/glv.rs index e4f84886..42444f48 100644 --- a/plonky2/src/curve/glv.rs +++ b/plonky2/src/curve/glv.rs @@ -3,6 +3,10 @@ use plonky2_field::field_types::Field; use plonky2_field::secp256k1_base::Secp256K1Base; use plonky2_field::secp256k1_scalar::Secp256K1Scalar; +use crate::curve::curve_msm::msm_parallel; +use crate::curve::curve_types::{ProjectivePoint, AffinePoint}; +use crate::curve::secp256k1::Secp256K1; + pub const BETA: Secp256K1Base = Secp256K1Base([ 13923278643952681454, 11308619431505398165, @@ -43,15 +47,28 @@ pub fn decompose_secp256k1_scalar(k: Secp256K1Scalar) -> (Secp256K1Scalar, Secp2 (k1, k2) } +pub fn glv_mul(p: ProjectivePoint, k: Secp256K1Scalar) -> ProjectivePoint { + let (k1, k2) = decompose_secp256k1_scalar(k); + assert!(k1 + S * k2 == k); + + let p_affine = p.to_affine(); + let sp = AffinePoint:: { + x: p_affine.x * BETA, + y: p_affine.y, + zero: p_affine.zero, + }; + + msm_parallel(&[k1, k2], &[p, sp.to_projective()], 5) +} + #[cfg(test)] mod tests { use anyhow::Result; use plonky2_field::field_types::Field; use plonky2_field::secp256k1_scalar::Secp256K1Scalar; - use crate::curve::curve_msm::msm_parallel; - use crate::curve::curve_types::{AffinePoint, Curve, CurveScalar}; - use crate::curve::glv::{decompose_secp256k1_scalar, BETA, S}; + use crate::curve::curve_types::{Curve, CurveScalar}; + use crate::curve::glv::{decompose_secp256k1_scalar, glv_mul, S}; use crate::curve::secp256k1::Secp256K1; #[test] @@ -68,20 +85,11 @@ mod tests { fn test_glv_mul() -> Result<()> { for _ in 0..20 { let k = Secp256K1Scalar::rand(); - let (k1, k2) = decompose_secp256k1_scalar(k); - assert!(k1 + S * k2 == k); + let p = CurveScalar(Secp256K1Scalar::rand()) * Secp256K1::GENERATOR_PROJECTIVE; - let p = (CurveScalar(Secp256K1Scalar::rand()) * Secp256K1::GENERATOR_PROJECTIVE) - .to_affine(); - let sp = AffinePoint:: { - x: p.x * BETA, - y: p.y, - zero: p.zero, - }; - - let kp = CurveScalar(k) * p.to_projective(); - let glv = msm_parallel(&[k1, k2], &[p.to_projective(), sp.to_projective()], 5); + let kp = CurveScalar(k) * p; + let glv = glv_mul(p, k); assert!(kp == glv); }