diff --git a/plonky2/src/gadgets/curve.rs b/plonky2/src/gadgets/curve.rs index aa8c4112..a1fc3a8b 100644 --- a/plonky2/src/gadgets/curve.rs +++ b/plonky2/src/gadgets/curve.rs @@ -78,15 +78,12 @@ impl, const D: usize> CircuitBuilder { ) -> AffinePointTarget { let not_b = self.not(b); let neg = self.curve_neg(p); - let x_if_true = self.mul_nonnative_by_bool(&neg.x, b); let y_if_true = self.mul_nonnative_by_bool(&neg.y, b); - let x_if_false = self.mul_nonnative_by_bool(&p.x, not_b); let y_if_false = self.mul_nonnative_by_bool(&p.y, not_b); - let x = self.add_nonnative(&x_if_true, &x_if_false); let y = self.add_nonnative(&y_if_true, &y_if_false); - AffinePointTarget { x, y } + AffinePointTarget { x: p.x.clone(), y } } pub fn curve_double(&mut self, p: &AffinePointTarget) -> AffinePointTarget { diff --git a/plonky2/src/gadgets/curve_fixed_base.rs b/plonky2/src/gadgets/curve_fixed_base.rs index 3b826f78..70def6bc 100644 --- a/plonky2/src/gadgets/curve_fixed_base.rs +++ b/plonky2/src/gadgets/curve_fixed_base.rs @@ -11,6 +11,8 @@ use crate::plonk::circuit_builder::CircuitBuilder; use crate::plonk::config::{GenericHashOut, Hasher}; impl, const D: usize> CircuitBuilder { + /// Do windowed fixed-base scalar multiplication, using a 4-bit window. + // TODO: Benchmark other window sizes. pub fn fixed_base_curve_mul( &mut self, base: &AffinePoint, diff --git a/plonky2/src/gadgets/curve_msm.rs b/plonky2/src/gadgets/curve_msm.rs index 43a22da2..df13e8f3 100644 --- a/plonky2/src/gadgets/curve_msm.rs +++ b/plonky2/src/gadgets/curve_msm.rs @@ -3,8 +3,8 @@ use plonky2_field::extension_field::Extendable; use crate::curve::curve_types::{Curve, CurveScalar}; use crate::field::field_types::Field; +use crate::gadgets::biguint::BigUintTarget; use crate::gadgets::curve::AffinePointTarget; -use crate::gadgets::nonnative::NonNativeTarget; use crate::hash::hash_types::RichField; use crate::hash::keccak::KeccakHash; use crate::plonk::circuit_builder::CircuitBuilder; @@ -16,12 +16,13 @@ impl, const D: usize> CircuitBuilder { &mut self, p: &AffinePointTarget, q: &AffinePointTarget, - n: &NonNativeTarget, - m: &NonNativeTarget, + n: &BigUintTarget, + m: &BigUintTarget, ) -> AffinePointTarget { - let limbs_n = self.split_nonnative_to_2_bit_limbs(n); - let limbs_m = self.split_nonnative_to_2_bit_limbs(m); + let limbs_n = self.split_biguint_to_2_bit_limbs(n); + let limbs_m = self.split_biguint_to_2_bit_limbs(m); assert_eq!(limbs_n.len(), limbs_m.len()); + let num_limbs = limbs_n.len(); let hash_0 = KeccakHash::<32>::hash_no_pad(&[F::ZERO]); let hash_0_scalar = C::ScalarField::from_biguint(BigUint::from_bytes_le( @@ -63,8 +64,7 @@ impl, const D: usize> CircuitBuilder { let should_add = self.not(is_zero); result = self.curve_conditional_add(&result, &r, should_add); } - let starting_point_multiplied = - (0..C::ScalarField::BITS).fold(rando, |acc, _| acc.double()); + let starting_point_multiplied = (0..2 * num_limbs).fold(rando, |acc, _| acc.double()); let to_add = self.constant_affine_point(-starting_point_multiplied); result = self.curve_add(&result, &to_add); @@ -74,11 +74,14 @@ impl, const D: usize> CircuitBuilder { #[cfg(test)] mod tests { + use std::str::FromStr; use anyhow::Result; + use num::BigUint; + use plonky2_field::secp256k1_base::Secp256K1Base; use plonky2_field::secp256k1_scalar::Secp256K1Scalar; - use crate::curve::curve_types::{Curve, CurveScalar}; + use crate::curve::curve_types::{AffinePoint, Curve, CurveScalar}; use crate::curve::secp256k1::Secp256K1; use crate::field::field_types::Field; use crate::iop::witness::PartialWitness; @@ -115,7 +118,7 @@ mod tests { let n_target = builder.constant_nonnative(n); let m_target = builder.constant_nonnative(m); - let res_target = builder.curve_msm(&p_target, &q_target, &n_target, &m_target); + let res_target = builder.curve_msm(&p_target, &q_target, &n_target.value, &m_target.value); builder.curve_assert_valid(&res_target); builder.connect_affine_point(&res_target, &res_expected); @@ -168,4 +171,72 @@ mod tests { verify(proof, &data.verifier_only, &data.common) } + + #[test] + fn test_curve_lul() -> Result<()> { + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + + let config = CircuitConfig::standard_ecc_config(); + + let pw = PartialWitness::new(); + let mut builder = CircuitBuilder::::new(config); + + let p = AffinePoint:: { + x: Secp256K1Base::from_biguint( + BigUint::from_str( + "95702873347299649035220040874584348285675823985309557645567012532974768144045", + ) + .unwrap(), + ), + y: Secp256K1Base::from_biguint( + BigUint::from_str( + "34849299245821426255020320369755722155634282348110887335812955146294938249053", + ) + .unwrap(), + ), + zero: false, + }; + let q = AffinePoint:: { + x: Secp256K1Base::from_biguint( + BigUint::from_str( + "66037057977021147605301350925941983227524093291368248236634649161657340356645", + ) + .unwrap(), + ), + y: Secp256K1Base::from_biguint( + BigUint::from_str( + "80942789991494769168550664638932185697635702317529676703644628861613896422610", + ) + .unwrap(), + ), + zero: false, + }; + + let n = BigUint::from_str("89874493710619023150462632713212469930").unwrap(); + let m = BigUint::from_str("76073901947022186525975758425319149118").unwrap(); + + let res = (CurveScalar(Secp256K1Scalar::from_biguint(n.clone())) * p.to_projective() + + CurveScalar(Secp256K1Scalar::from_biguint(m.clone())) * q.to_projective()) + .to_affine(); + let res_expected = builder.constant_affine_point(res); + builder.curve_assert_valid(&res_expected); + + let p_target = builder.constant_affine_point(p); + let q_target = builder.constant_affine_point(q); + let n_target = builder.constant_biguint(&n); + let m_target = builder.constant_biguint(&m); + + let res_target = builder.curve_msm(&p_target, &q_target, &n_target, &m_target); + builder.curve_assert_valid(&res_target); + + builder.connect_affine_point(&res_target, &res_expected); + + dbg!(builder.num_gates()); + let data = builder.build::(); + let proof = data.prove(pw).unwrap(); + + verify(proof, &data.verifier_only, &data.common) + } } diff --git a/plonky2/src/gadgets/glv.rs b/plonky2/src/gadgets/glv.rs index b9a3a380..e0a4cfaa 100644 --- a/plonky2/src/gadgets/glv.rs +++ b/plonky2/src/gadgets/glv.rs @@ -43,6 +43,8 @@ impl, const D: usize> CircuitBuilder { _phantom: PhantomData, }); + // debug_assert!(k1_raw + S * k2_raw == k); + (k1, k2, k1_neg, k2_neg) } @@ -60,24 +62,9 @@ impl, const D: usize> CircuitBuilder { y: p.y.clone(), }; - // let part1 = self.curve_scalar_mul_windowed(p, &k1); - // let part1_neg = self.curve_conditional_neg(&part1, k1_neg); - // let part2 = self.curve_scalar_mul_windowed(&sp, &k2); - // let part2_neg = self.curve_conditional_neg(&part2, k2_neg); - // - // self.curve_add(&part1_neg, &part2_neg) - // dbg!(k1.value.limbs.len()); - // dbg!(k2.value.limbs.len()); let p_neg = self.curve_conditional_neg(&p, k1_neg); let sp_neg = self.curve_conditional_neg(&sp, k2_neg); - // let yo = self.curve_scalar_mul_windowed(&p_neg, &k1); - // let ya = self.curve_scalar_mul_windowed(&sp_neg, &k2); - // dbg!(&yo); - // dbg!(&ya); - // self.connect_affine_point(&part1_neg, &yo); - // self.connect_affine_point(&part2_neg, &ya); - self.curve_msm(&p_neg, &sp_neg, &k1, &k2) - // self.curve_add(&yo, &ya) + self.curve_msm(&p_neg, &sp_neg, &k1.value, &k2.value) } } @@ -118,7 +105,7 @@ mod tests { use crate::curve::curve_types::{Curve, CurveScalar}; use crate::curve::glv::glv_mul; use crate::curve::secp256k1::Secp256K1; - use crate::iop::witness::{PartialWitness, Witness}; + use crate::iop::witness::PartialWitness; use crate::plonk::circuit_builder::CircuitBuilder; use crate::plonk::circuit_data::CircuitConfig; use crate::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; @@ -153,40 +140,4 @@ mod tests { verify(proof, &data.verifier_only, &data.common) } - - #[test] - fn test_wtf() -> Result<()> { - const D: usize = 2; - type C = PoseidonGoldilocksConfig; - type F = >::F; - - let config = CircuitConfig::standard_ecc_config(); - - let mut pw = PartialWitness::new(); - let mut builder = CircuitBuilder::::new(config); - - let rando = - (CurveScalar(Secp256K1Scalar::rand()) * Secp256K1::GENERATOR_PROJECTIVE).to_affine(); - let randot = builder.constant_affine_point(rando); - - let scalar = Secp256K1Scalar::rand(); - let scalar_target = builder.constant_nonnative(scalar); - - let tr = builder.add_virtual_bool_target(); - pw.set_bool_target(tr, false); - - let randotneg = builder.curve_conditional_neg(&randot, tr); - let y = builder.curve_scalar_mul_windowed(&randotneg, &scalar_target); - - let yy = builder.curve_scalar_mul_windowed(&randot, &scalar_target); - let yy = builder.curve_conditional_neg(&yy, tr); - - builder.connect_affine_point(&y, &yy); - - dbg!(builder.num_gates()); - let data = builder.build::(); - let proof = data.prove(pw).unwrap(); - - verify(proof, &data.verifier_only, &data.common) - } } diff --git a/plonky2/src/gadgets/nonnative.rs b/plonky2/src/gadgets/nonnative.rs index 910915d0..73bc0ad3 100644 --- a/plonky2/src/gadgets/nonnative.rs +++ b/plonky2/src/gadgets/nonnative.rs @@ -454,7 +454,7 @@ impl, const D: usize, FF: PrimeField> SimpleGenerat let b_biguint = b.to_canonical_biguint(); let modulus = FF::order(); - let (diff_biguint, overflow) = if a_biguint > b_biguint { + let (diff_biguint, overflow) = if a_biguint >= b_biguint { (a_biguint - b_biguint, false) } else { (modulus + a_biguint - b_biguint, true) diff --git a/plonky2/src/gadgets/split_nonnative.rs b/plonky2/src/gadgets/split_nonnative.rs index 18fc0264..becf1177 100644 --- a/plonky2/src/gadgets/split_nonnative.rs +++ b/plonky2/src/gadgets/split_nonnative.rs @@ -35,12 +35,8 @@ impl, const D: usize> CircuitBuilder { .collect() } - pub fn split_nonnative_to_2_bit_limbs( - &mut self, - val: &NonNativeTarget, - ) -> Vec { - val.value - .limbs + pub fn split_biguint_to_2_bit_limbs(&mut self, val: &BigUintTarget) -> Vec { + val.limbs .iter() .flat_map(|&l| self.split_le_base::<4>(l.0, 16)) .collect() diff --git a/plonky2/src/iop/generator.rs b/plonky2/src/iop/generator.rs index 1569e889..4dcd11da 100644 --- a/plonky2/src/iop/generator.rs +++ b/plonky2/src/iop/generator.rs @@ -89,6 +89,68 @@ pub(crate) fn generate_partial_witness< } pending_generator_indices = next_pending_generator_indices; + // for t in [ + // Target::VirtualTarget { index: 57934 }, + // Target::VirtualTarget { index: 57935 }, + // Target::VirtualTarget { index: 57936 }, + // Target::VirtualTarget { index: 57937 }, + // Target::VirtualTarget { index: 57938 }, + // Target::VirtualTarget { index: 57939 }, + // Target::VirtualTarget { index: 57940 }, + // Target::VirtualTarget { index: 57941 }, + // ] { + // if let Some(v) = witness.try_get_target(t) { + // println!("a {}", v); + // } + // } + // for t in [ + // Target::VirtualTarget { index: 57952 }, + // Target::VirtualTarget { index: 57953 }, + // Target::VirtualTarget { index: 57954 }, + // Target::VirtualTarget { index: 57955 }, + // Target::VirtualTarget { index: 57956 }, + // Target::VirtualTarget { index: 57957 }, + // Target::VirtualTarget { index: 57958 }, + // Target::VirtualTarget { index: 57959 }, + // ] { + // if let Some(v) = witness.try_get_target(t) { + // println!("b {}", v); + // } + // } + // + // let t = Target::Wire(Wire { + // gate: 141_857, + // input: 8, + // }); + // if let Some(v) = witness.try_get_target(t) { + // println!("prod_exp {}", v); + // } + // let t = Target::Wire(Wire { + // gate: 141_863, + // input: 22, + // }); + // if let Some(v) = witness.try_get_target(t) { + // println!("prod act {}", v); + // } + // let t = Target::Wire(Wire { gate: 9, input: 3 }); + // if let Some(v) = witness.try_get_target(t) { + // println!("modulus {}", v); + // } + // let t = Target::VirtualTarget { index: 57_976 }; + // if let Some(v) = witness.try_get_target(t) { + // println!("overflow {}", v); + // } + // let t = Target::Wire(Wire { + // gate: 141_885, + // input: 8, + // }); + // if let Some(v) = witness.try_get_target(t) { + // println!("mod time ov {}", v); + // } + // let t = Target::VirtualTarget { index: 57_968 }; + // if let Some(v) = witness.try_get_target(t) { + // println!("prod {}", v); + // } } assert_eq!(