From 2f4da9b49d61f1e8dc650f3a858e44bc92574472 Mon Sep 17 00:00:00 2001 From: Nicholas Ward Date: Thu, 3 Feb 2022 10:52:44 -0800 Subject: [PATCH 01/56] added native GLV compose --- plonky2/src/curve/glv.rs | 62 ++++++++++++++++++++++++++++++++++++++++ plonky2/src/curve/mod.rs | 1 + 2 files changed, 63 insertions(+) create mode 100644 plonky2/src/curve/glv.rs diff --git a/plonky2/src/curve/glv.rs b/plonky2/src/curve/glv.rs new file mode 100644 index 00000000..6777810e --- /dev/null +++ b/plonky2/src/curve/glv.rs @@ -0,0 +1,62 @@ +use num::rational::Ratio; +use plonky2_field::field_types::Field; +use plonky2_field::secp256k1_scalar::Secp256K1Scalar; + +const BETA: Secp256K1Scalar = Secp256K1Scalar([ + 13923278643952681454, + 11308619431505398165, + 7954561588662645993, + 8856726876819556112, +]); + +const S: Secp256K1Scalar = Secp256K1Scalar([ + 16069571880186789234, + 1310022930574435960, + 11900229862571533402, + 6008836872998760672, +]); + +const A1: Secp256K1Scalar = Secp256K1Scalar([16747920425669159701, 3496713202691238861, 0, 0]); + +const MINUS_B1: Secp256K1Scalar = + Secp256K1Scalar([8022177200260244675, 16448129721693014056, 0, 0]); + +const A2: Secp256K1Scalar = Secp256K1Scalar([6323353552219852760, 1498098850674701302, 1, 0]); + +const B2: Secp256K1Scalar = Secp256K1Scalar([16747920425669159701, 3496713202691238861, 0, 0]); + +pub fn decompose_secp256k1_scalar(k: Secp256K1Scalar) -> (Secp256K1Scalar, Secp256K1Scalar) { + let p = Secp256K1Scalar::order(); + let c1_biguint = Ratio::new(B2.to_biguint() * k.to_biguint(), p.clone()) + .round() + .to_integer(); + let c1 = Secp256K1Scalar::from_biguint(c1_biguint); + let c2_biguint = Ratio::new(MINUS_B1.to_biguint() * k.to_biguint(), p.clone()) + .round() + .to_integer(); + let c2 = Secp256K1Scalar::from_biguint(c2_biguint); + + let k1 = k - c1 * A1 - c2 * A2; + let k2 = c1 * MINUS_B1 - c2 * B2; + debug_assert!(k1 + S * k2 == k); + (k1, k2) +} + +#[cfg(test)] +mod tests { + use anyhow::Result; + use plonky2_field::field_types::Field; + use plonky2_field::secp256k1_scalar::Secp256K1Scalar; + + use crate::curve::glv::{decompose_secp256k1_scalar, S}; + + #[test] + fn test_glv_decompose() -> Result<()> { + let k = Secp256K1Scalar::rand(); + let (k1, k2) = decompose_secp256k1_scalar(k); + + assert!(k1 + S * k2 == k); + + Ok(()) + } +} diff --git a/plonky2/src/curve/mod.rs b/plonky2/src/curve/mod.rs index 8dd6f0d6..1984b0c6 100644 --- a/plonky2/src/curve/mod.rs +++ b/plonky2/src/curve/mod.rs @@ -4,4 +4,5 @@ pub mod curve_multiplication; pub mod curve_summation; pub mod curve_types; pub mod ecdsa; +pub mod glv; pub mod secp256k1; From fd7abb35da88d1e1e1ea980ac17b74dc58c1f555 Mon Sep 17 00:00:00 2001 From: Nicholas Ward Date: Thu, 3 Feb 2022 11:02:37 -0800 Subject: [PATCH 02/56] GLV mul --- plonky2/src/curve/glv.rs | 33 +++++++++++++++++++++++++++++++-- 1 file changed, 31 insertions(+), 2 deletions(-) diff --git a/plonky2/src/curve/glv.rs b/plonky2/src/curve/glv.rs index 6777810e..1001e013 100644 --- a/plonky2/src/curve/glv.rs +++ b/plonky2/src/curve/glv.rs @@ -1,8 +1,9 @@ use num::rational::Ratio; use plonky2_field::field_types::Field; +use plonky2_field::secp256k1_base::Secp256K1Base; use plonky2_field::secp256k1_scalar::Secp256K1Scalar; -const BETA: Secp256K1Scalar = Secp256K1Scalar([ +pub const BETA: Secp256K1Base = Secp256K1Base([ 13923278643952681454, 11308619431505398165, 7954561588662645993, @@ -48,7 +49,10 @@ mod tests { use plonky2_field::field_types::Field; use plonky2_field::secp256k1_scalar::Secp256K1Scalar; - use crate::curve::glv::{decompose_secp256k1_scalar, S}; + use crate::curve::curve_msm::msm_parallel; + use crate::curve::curve_types::{AffinePoint, Curve, CurveScalar}; + use crate::curve::glv::{decompose_secp256k1_scalar, BETA, S}; + use crate::curve::secp256k1::Secp256K1; #[test] fn test_glv_decompose() -> Result<()> { @@ -59,4 +63,29 @@ mod tests { Ok(()) } + + #[test] + fn test_glv_mul() -> Result<()> { + for _ in 0..20 { + let k = Secp256K1Scalar::rand(); + let (k1, k2) = decompose_secp256k1_scalar(k); + + assert!(k1 + S * k2 == k); + + let p = (CurveScalar(Secp256K1Scalar::rand()) * Secp256K1::GENERATOR_PROJECTIVE) + .to_affine(); + let sp = AffinePoint:: { + x: p.x * BETA, + y: p.y, + zero: p.zero, + }; + + let kp = CurveScalar(k) * p.to_projective(); + let glv = msm_parallel(&[k1, k2], &[p.to_projective(), sp.to_projective()], 5); + + assert!(kp == glv); + } + + Ok(()) + } } From c279c779a3e8033f8e40728bf262f0cfd63ef847 Mon Sep 17 00:00:00 2001 From: Nicholas Ward Date: Thu, 3 Feb 2022 11:04:09 -0800 Subject: [PATCH 03/56] fixed clippy --- plonky2/src/curve/glv.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/plonky2/src/curve/glv.rs b/plonky2/src/curve/glv.rs index 1001e013..e4f84886 100644 --- a/plonky2/src/curve/glv.rs +++ b/plonky2/src/curve/glv.rs @@ -32,7 +32,7 @@ pub fn decompose_secp256k1_scalar(k: Secp256K1Scalar) -> (Secp256K1Scalar, Secp2 .round() .to_integer(); let c1 = Secp256K1Scalar::from_biguint(c1_biguint); - let c2_biguint = Ratio::new(MINUS_B1.to_biguint() * k.to_biguint(), p.clone()) + let c2_biguint = Ratio::new(MINUS_B1.to_biguint() * k.to_biguint(), p) .round() .to_integer(); let c2 = Secp256K1Scalar::from_biguint(c2_biguint); From c3126796c0d93616846dd83268da5140ff29e4e9 Mon Sep 17 00:00:00 2001 From: Nicholas Ward Date: Thu, 3 Feb 2022 16:24:54 -0800 Subject: [PATCH 04/56] GLV in circuit --- plonky2/src/gadgets/glv.rs | 89 ++++++++++++++++++++++++++++++++++++++ plonky2/src/gadgets/mod.rs | 1 + 2 files changed, 90 insertions(+) create mode 100644 plonky2/src/gadgets/glv.rs diff --git a/plonky2/src/gadgets/glv.rs b/plonky2/src/gadgets/glv.rs new file mode 100644 index 00000000..c5790640 --- /dev/null +++ b/plonky2/src/gadgets/glv.rs @@ -0,0 +1,89 @@ +use std::marker::PhantomData; + +use plonky2_field::extension_field::Extendable; +use plonky2_field::secp256k1_base::Secp256K1Base; +use plonky2_field::secp256k1_scalar::Secp256K1Scalar; + +use crate::curve::glv::{decompose_secp256k1_scalar, BETA}; +use crate::curve::secp256k1::Secp256K1; +use crate::gadgets::curve::AffinePointTarget; +use crate::gadgets::nonnative::NonNativeTarget; +use crate::hash::hash_types::RichField; +use crate::iop::generator::{GeneratedValues, SimpleGenerator}; +use crate::iop::target::Target; +use crate::iop::witness::{PartitionWitness, Witness}; +use crate::plonk::circuit_builder::CircuitBuilder; + +impl, const D: usize> CircuitBuilder { + pub fn secp256k1_glv_beta(&mut self) -> NonNativeTarget { + self.constant_nonnative(BETA) + } + + pub fn decompose_secp256k1_scalar( + &mut self, + k: &NonNativeTarget, + ) -> ( + NonNativeTarget, + NonNativeTarget, + ) { + let k1 = self.add_virtual_nonnative_target::(); + let k2 = self.add_virtual_nonnative_target::(); + + self.add_simple_generator(GLVDecompositionGenerator:: { + k: k.clone(), + k1: k1.clone(), + k2: k2.clone(), + _phantom: PhantomData, + }); + + (k1, k2) + } + + pub fn glv_mul( + &mut self, + k: &NonNativeTarget, + p: &AffinePointTarget, + ) -> AffinePointTarget { + let (k1, k2) = self.decompose_secp256k1_scalar(k); + + let beta = self.secp256k1_glv_beta(); + let beta_px = self.mul_nonnative(&beta, &p.x); + let sp = AffinePointTarget:: { + x: beta_px, + y: p.y.clone(), + }; + + // TODO: replace with MSM + let part1 = self.curve_scalar_mul(&p, &k1); + let part2 = self.curve_scalar_mul(&sp, &k2); + + self.curve_add(&part1, &part2) + } +} + +#[derive(Debug)] +struct GLVDecompositionGenerator, const D: usize> { + k: NonNativeTarget, + k1: NonNativeTarget, + k2: NonNativeTarget, + _phantom: PhantomData, +} + +impl, const D: usize> SimpleGenerator + for GLVDecompositionGenerator +{ + fn dependencies(&self) -> Vec { + self.k.value.limbs.iter().map(|l| l.0).collect() + } + + fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { + let k = witness.get_nonnative_target(self.k.clone()); + let (k1, k2) = decompose_secp256k1_scalar(k); + + out_buffer.set_nonnative_target(self.k1.clone(), k1); + out_buffer.set_nonnative_target(self.k2.clone(), k2); + } +} + +#[cfg(test)] +mod tests {} diff --git a/plonky2/src/gadgets/mod.rs b/plonky2/src/gadgets/mod.rs index ec4d1263..c3dd4a54 100644 --- a/plonky2/src/gadgets/mod.rs +++ b/plonky2/src/gadgets/mod.rs @@ -4,6 +4,7 @@ pub mod arithmetic_u32; pub mod biguint; pub mod curve; pub mod ecdsa; +pub mod glv; pub mod hash; pub mod interpolation; pub mod multiple_comparison; From e92d4c25beab3cd89cd23cd1947c3d7241a9e04e Mon Sep 17 00:00:00 2001 From: Nicholas Ward Date: Thu, 3 Feb 2022 16:35:10 -0800 Subject: [PATCH 05/56] fixed clippy --- plonky2/src/gadgets/glv.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/plonky2/src/gadgets/glv.rs b/plonky2/src/gadgets/glv.rs index c5790640..9fdcdaa2 100644 --- a/plonky2/src/gadgets/glv.rs +++ b/plonky2/src/gadgets/glv.rs @@ -54,7 +54,7 @@ impl, const D: usize> CircuitBuilder { }; // TODO: replace with MSM - let part1 = self.curve_scalar_mul(&p, &k1); + let part1 = self.curve_scalar_mul(p, &k1); let part2 = self.curve_scalar_mul(&sp, &k2); self.curve_add(&part1, &part2) From 5917a09cee76a4479d120a84f6457de4825f105c Mon Sep 17 00:00:00 2001 From: Nicholas Ward Date: Fri, 4 Feb 2022 08:36:40 -0800 Subject: [PATCH 06/56] split out glv_mul function --- plonky2/src/curve/glv.rs | 38 +++++++++++++++++++++++--------------- 1 file changed, 23 insertions(+), 15 deletions(-) diff --git a/plonky2/src/curve/glv.rs b/plonky2/src/curve/glv.rs index e4f84886..42444f48 100644 --- a/plonky2/src/curve/glv.rs +++ b/plonky2/src/curve/glv.rs @@ -3,6 +3,10 @@ use plonky2_field::field_types::Field; use plonky2_field::secp256k1_base::Secp256K1Base; use plonky2_field::secp256k1_scalar::Secp256K1Scalar; +use crate::curve::curve_msm::msm_parallel; +use crate::curve::curve_types::{ProjectivePoint, AffinePoint}; +use crate::curve::secp256k1::Secp256K1; + pub const BETA: Secp256K1Base = Secp256K1Base([ 13923278643952681454, 11308619431505398165, @@ -43,15 +47,28 @@ pub fn decompose_secp256k1_scalar(k: Secp256K1Scalar) -> (Secp256K1Scalar, Secp2 (k1, k2) } +pub fn glv_mul(p: ProjectivePoint, k: Secp256K1Scalar) -> ProjectivePoint { + let (k1, k2) = decompose_secp256k1_scalar(k); + assert!(k1 + S * k2 == k); + + let p_affine = p.to_affine(); + let sp = AffinePoint:: { + x: p_affine.x * BETA, + y: p_affine.y, + zero: p_affine.zero, + }; + + msm_parallel(&[k1, k2], &[p, sp.to_projective()], 5) +} + #[cfg(test)] mod tests { use anyhow::Result; use plonky2_field::field_types::Field; use plonky2_field::secp256k1_scalar::Secp256K1Scalar; - use crate::curve::curve_msm::msm_parallel; - use crate::curve::curve_types::{AffinePoint, Curve, CurveScalar}; - use crate::curve::glv::{decompose_secp256k1_scalar, BETA, S}; + use crate::curve::curve_types::{Curve, CurveScalar}; + use crate::curve::glv::{decompose_secp256k1_scalar, glv_mul, S}; use crate::curve::secp256k1::Secp256K1; #[test] @@ -68,20 +85,11 @@ mod tests { fn test_glv_mul() -> Result<()> { for _ in 0..20 { let k = Secp256K1Scalar::rand(); - let (k1, k2) = decompose_secp256k1_scalar(k); - assert!(k1 + S * k2 == k); + let p = CurveScalar(Secp256K1Scalar::rand()) * Secp256K1::GENERATOR_PROJECTIVE; - let p = (CurveScalar(Secp256K1Scalar::rand()) * Secp256K1::GENERATOR_PROJECTIVE) - .to_affine(); - let sp = AffinePoint:: { - x: p.x * BETA, - y: p.y, - zero: p.zero, - }; - - let kp = CurveScalar(k) * p.to_projective(); - let glv = msm_parallel(&[k1, k2], &[p.to_projective(), sp.to_projective()], 5); + let kp = CurveScalar(k) * p; + let glv = glv_mul(p, k); assert!(kp == glv); } From 5aaa5710a83ed4275b0e0480f07d9445e80cab42 Mon Sep 17 00:00:00 2001 From: Nicholas Ward Date: Fri, 4 Feb 2022 08:38:03 -0800 Subject: [PATCH 07/56] test for GLV gadget --- plonky2/src/gadgets/glv.rs | 44 ++++++++++++++++++++++++++++++++++++-- plonky2/src/gadgets/mod.rs | 1 + 2 files changed, 43 insertions(+), 2 deletions(-) diff --git a/plonky2/src/gadgets/glv.rs b/plonky2/src/gadgets/glv.rs index 9fdcdaa2..032df761 100644 --- a/plonky2/src/gadgets/glv.rs +++ b/plonky2/src/gadgets/glv.rs @@ -41,8 +41,8 @@ impl, const D: usize> CircuitBuilder { pub fn glv_mul( &mut self, - k: &NonNativeTarget, p: &AffinePointTarget, + k: &NonNativeTarget, ) -> AffinePointTarget { let (k1, k2) = self.decompose_secp256k1_scalar(k); @@ -86,4 +86,44 @@ impl, const D: usize> SimpleGenerator } #[cfg(test)] -mod tests {} +mod tests { + use anyhow::Result; + use plonky2_field::field_types::Field; + use plonky2_field::secp256k1_scalar::Secp256K1Scalar; + + use crate::curve::curve_types::{Curve, CurveScalar}; + use crate::curve::secp256k1::Secp256K1; + use crate::iop::witness::PartialWitness; + use crate::plonk::circuit_builder::CircuitBuilder; + use crate::plonk::circuit_data::CircuitConfig; + use crate::plonk::config::{PoseidonGoldilocksConfig, GenericConfig}; + use crate::plonk::verifier::verify; + + #[test] + fn test_glv() -> Result<()> { + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + + let config = CircuitConfig::standard_ecc_config(); + + let pw = PartialWitness::new(); + let mut builder = CircuitBuilder::::new(config); + + let rando = + (CurveScalar(Secp256K1Scalar::rand()) * Secp256K1::GENERATOR_PROJECTIVE).to_affine(); + let randot = builder.constant_affine_point(rando); + + let scalar = Secp256K1Scalar::rand(); + let scalar_target = builder.constant_nonnative(scalar); + + let randot_times_scalar = builder.curve_scalar_mul(&randot, &scalar_target); + let randot_glv_scalar = builder.glv_mul(&randot, &scalar_target); + builder.connect_affine_point(&randot_times_scalar, &randot_glv_scalar); + + let data = builder.build::(); + let proof = data.prove(pw).unwrap(); + + verify(proof, &data.verifier_only, &data.common) + } +} diff --git a/plonky2/src/gadgets/mod.rs b/plonky2/src/gadgets/mod.rs index c3dd4a54..9d6f7686 100644 --- a/plonky2/src/gadgets/mod.rs +++ b/plonky2/src/gadgets/mod.rs @@ -3,6 +3,7 @@ pub mod arithmetic_extension; pub mod arithmetic_u32; pub mod biguint; pub mod curve; +// pub mod curve_msm; pub mod ecdsa; pub mod glv; pub mod hash; From 140f0590bc8df20840b25dcbf0945954ffbd15cc Mon Sep 17 00:00:00 2001 From: Nicholas Ward Date: Fri, 4 Feb 2022 08:38:21 -0800 Subject: [PATCH 08/56] fmt --- plonky2/src/curve/glv.rs | 2 +- plonky2/src/gadgets/glv.rs | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/plonky2/src/curve/glv.rs b/plonky2/src/curve/glv.rs index 42444f48..ded0d0b3 100644 --- a/plonky2/src/curve/glv.rs +++ b/plonky2/src/curve/glv.rs @@ -4,7 +4,7 @@ use plonky2_field::secp256k1_base::Secp256K1Base; use plonky2_field::secp256k1_scalar::Secp256K1Scalar; use crate::curve::curve_msm::msm_parallel; -use crate::curve::curve_types::{ProjectivePoint, AffinePoint}; +use crate::curve::curve_types::{AffinePoint, ProjectivePoint}; use crate::curve::secp256k1::Secp256K1; pub const BETA: Secp256K1Base = Secp256K1Base([ diff --git a/plonky2/src/gadgets/glv.rs b/plonky2/src/gadgets/glv.rs index 032df761..a75ead7e 100644 --- a/plonky2/src/gadgets/glv.rs +++ b/plonky2/src/gadgets/glv.rs @@ -96,7 +96,7 @@ mod tests { use crate::iop::witness::PartialWitness; use crate::plonk::circuit_builder::CircuitBuilder; use crate::plonk::circuit_data::CircuitConfig; - use crate::plonk::config::{PoseidonGoldilocksConfig, GenericConfig}; + use crate::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; use crate::plonk::verifier::verify; #[test] From dc44baa592c4f7b5659123cf7f8ad5722ef9be05 Mon Sep 17 00:00:00 2001 From: Nicholas Ward Date: Fri, 4 Feb 2022 09:27:02 -0800 Subject: [PATCH 09/56] simpler test --- plonky2/src/gadgets/curve.rs | 1 + plonky2/src/gadgets/glv.rs | 8 +++++--- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/plonky2/src/gadgets/curve.rs b/plonky2/src/gadgets/curve.rs index 8c182345..9183a07a 100644 --- a/plonky2/src/gadgets/curve.rs +++ b/plonky2/src/gadgets/curve.rs @@ -329,6 +329,7 @@ mod tests { } #[test] + #[ignore] fn test_curve_random() -> Result<()> { const D: usize = 2; type C = PoseidonGoldilocksConfig; diff --git a/plonky2/src/gadgets/glv.rs b/plonky2/src/gadgets/glv.rs index a75ead7e..0a4afb3c 100644 --- a/plonky2/src/gadgets/glv.rs +++ b/plonky2/src/gadgets/glv.rs @@ -92,6 +92,7 @@ mod tests { use plonky2_field::secp256k1_scalar::Secp256K1Scalar; use crate::curve::curve_types::{Curve, CurveScalar}; + use crate::curve::glv::glv_mul; use crate::curve::secp256k1::Secp256K1; use crate::iop::witness::PartialWitness; use crate::plonk::circuit_builder::CircuitBuilder; @@ -117,9 +118,10 @@ mod tests { let scalar = Secp256K1Scalar::rand(); let scalar_target = builder.constant_nonnative(scalar); - let randot_times_scalar = builder.curve_scalar_mul(&randot, &scalar_target); - let randot_glv_scalar = builder.glv_mul(&randot, &scalar_target); - builder.connect_affine_point(&randot_times_scalar, &randot_glv_scalar); + let rando_glv_scalar = glv_mul(rando.to_projective(), scalar); + let expected = builder.constant_affine_point(rando_glv_scalar.to_affine()); + let actual = builder.glv_mul(&randot, &scalar_target); + builder.connect_affine_point(&expected, &actual); let data = builder.build::(); let proof = data.prove(pw).unwrap(); From 53a2a92258aec1fbb1abfe8a5594937f3bd852cd Mon Sep 17 00:00:00 2001 From: Nicholas Ward Date: Wed, 16 Feb 2022 11:30:32 -0800 Subject: [PATCH 10/56] windowed multiplication in circuit --- plonky2/src/fri/recursive_verifier.rs | 3 +- plonky2/src/gadgets/curve.rs | 83 +++++ plonky2/src/gadgets/mod.rs | 1 + plonky2/src/gadgets/random_access.rs | 26 +- plonky2/src/gadgets/split_nonnative.rs | 34 ++ plonky2/src/hash/merkle_proofs.rs | 12 +- plonky2/src/plonk/circuit_builder.rs | 440 +++++++++++++++++++++++++ 7 files changed, 579 insertions(+), 20 deletions(-) create mode 100644 plonky2/src/gadgets/split_nonnative.rs diff --git a/plonky2/src/fri/recursive_verifier.rs b/plonky2/src/fri/recursive_verifier.rs index 9b619ea8..f56e19ca 100644 --- a/plonky2/src/fri/recursive_verifier.rs +++ b/plonky2/src/fri/recursive_verifier.rs @@ -332,7 +332,8 @@ impl, const D: usize> CircuitBuilder { let x_index_within_coset = self.le_sum(x_index_within_coset_bits.iter()); // Check consistency with our old evaluation from the previous round. - self.random_access_extension(x_index_within_coset, old_eval, evals.clone()); + let new_eval = self.random_access_extension(x_index_within_coset, evals.clone()); + self.connect_extension(new_eval, old_eval); // Infer P(y) from {P(x)}_{x^arity=y}. old_eval = with_context!( diff --git a/plonky2/src/gadgets/curve.rs b/plonky2/src/gadgets/curve.rs index 9183a07a..c2af0104 100644 --- a/plonky2/src/gadgets/curve.rs +++ b/plonky2/src/gadgets/curve.rs @@ -1,11 +1,18 @@ +use std::marker::PhantomData; + use plonky2_field::extension_field::Extendable; use plonky2_field::field_types::Field; use crate::curve::curve_types::{AffinePoint, Curve, CurveScalar}; +use crate::gadgets::arithmetic_u32::U32Target; +use crate::gadgets::biguint::BigUintTarget; use crate::gadgets::nonnative::NonNativeTarget; use crate::hash::hash_types::RichField; +use crate::iop::target::Target; use crate::plonk::circuit_builder::CircuitBuilder; +const WINDOW_SIZE: usize = 4; + /// A Target representing an affine point on the curve `C`. We use incomplete arithmetic for efficiency, /// so we assume these points are not zero. #[derive(Clone, Debug)] @@ -157,6 +164,82 @@ impl, const D: usize> CircuitBuilder { result } + + pub fn precompute_window( + &mut self, + p: &AffinePointTarget, + ) -> Vec> { + let mut multiples = Vec::new(); + multiples.push(self.constant_affine_point(C::GENERATOR_AFFINE)); + let mut cur = p.clone(); + for _pow in 1..WINDOW_SIZE { + for existing in multiples.clone() { + multiples.push(self.curve_add(&cur, &existing)); + } + cur = self.curve_double(&cur); + } + + multiples + } + + pub fn random_access_curve_points( + &mut self, + access_index: Target, + v: Vec>, + ) -> AffinePointTarget { + let num_limbs = v[0].x.value.num_limbs(); + let x_limbs: Vec> = (0..num_limbs) + .map(|i| v.iter().map(|p| p.x.value.limbs[i].0).collect()) + .collect(); + let y_limbs: Vec> = (0..num_limbs) + .map(|i| v.iter().map(|p| p.y.value.limbs[i].0).collect()) + .collect(); + + let selected_x_limbs: Vec<_> = x_limbs + .iter() + .map(|limbs| U32Target(self.random_access(access_index, limbs.clone()))) + .collect(); + let selected_y_limbs: Vec<_> = y_limbs + .iter() + .map(|limbs| U32Target(self.random_access(access_index, limbs.clone()))) + .collect(); + + let x = NonNativeTarget { + value: BigUintTarget { + limbs: selected_x_limbs, + }, + _phantom: PhantomData, + }; + let y = NonNativeTarget { + value: BigUintTarget { + limbs: selected_y_limbs, + }, + _phantom: PhantomData, + }; + AffinePointTarget { x, y } + } + + pub fn curve_scalar_mul_windowed( + &mut self, + p: &AffinePointTarget, + n: &NonNativeTarget, + ) -> AffinePointTarget { + let mut result = self.constant_affine_point(C::GENERATOR_AFFINE); + + let precomputation = self.precompute_window(p); + + let windows = self.split_nonnative_to_4_bit_limbs(n); + let m = C::ScalarField::BITS / WINDOW_SIZE; + for i in m..0 { + result = self.curve_double(&result); + let window = windows[i]; + + let to_add = self.random_access_curve_points(window, precomputation.clone()); + result = self.curve_add(&result, &to_add); + } + + result + } } #[cfg(test)] diff --git a/plonky2/src/gadgets/mod.rs b/plonky2/src/gadgets/mod.rs index 9d6f7686..95e46b42 100644 --- a/plonky2/src/gadgets/mod.rs +++ b/plonky2/src/gadgets/mod.rs @@ -16,3 +16,4 @@ pub mod range_check; pub mod select; pub mod split_base; pub(crate) mod split_join; +pub mod split_nonnative; diff --git a/plonky2/src/gadgets/random_access.rs b/plonky2/src/gadgets/random_access.rs index 9518e9fa..ec3c889a 100644 --- a/plonky2/src/gadgets/random_access.rs +++ b/plonky2/src/gadgets/random_access.rs @@ -10,13 +10,15 @@ use crate::plonk::circuit_builder::CircuitBuilder; impl, const D: usize> CircuitBuilder { /// Checks that a `Target` matches a vector at a non-deterministic index. /// Note: `access_index` is not range-checked. - pub fn random_access(&mut self, access_index: Target, claimed_element: Target, v: Vec) { + pub fn random_access(&mut self, access_index: Target, v: Vec) -> Target { let vec_size = v.len(); let bits = log2_strict(vec_size); debug_assert!(vec_size > 0); if vec_size == 1 { - return self.connect(claimed_element, v[0]); + return v[0]; } + let claimed_element = self.add_virtual_target(); + let dummy_gate = RandomAccessGate::::new_from_config(&self.config, bits); let (gate_index, copy) = self.find_slot(dummy_gate, &[], &[]); @@ -34,6 +36,8 @@ impl, const D: usize> CircuitBuilder { claimed_element, Target::wire(gate_index, dummy_gate.wire_claimed_element(copy)), ); + + claimed_element } /// Checks that an `ExtensionTarget` matches a vector at a non-deterministic index. @@ -41,16 +45,13 @@ impl, const D: usize> CircuitBuilder { pub fn random_access_extension( &mut self, access_index: Target, - claimed_element: ExtensionTarget, v: Vec>, - ) { - for i in 0..D { - self.random_access( - access_index, - claimed_element.0[i], - v.iter().map(|et| et.0[i]).collect(), - ); - } + ) -> ExtensionTarget { + let v: Vec<_> = (0..D) + .map(|i| self.random_access(access_index, v.iter().map(|et| et.0[i]).collect())) + .collect(); + + ExtensionTarget(v.try_into().unwrap()) } } @@ -80,7 +81,8 @@ mod tests { for i in 0..len { let it = builder.constant(F::from_canonical_usize(i)); let elem = builder.constant_extension(vec[i]); - builder.random_access_extension(it, elem, v.clone()); + let res = builder.random_access_extension(it, v.clone()); + builder.connect_extension(elem, res); } let data = builder.build::(); diff --git a/plonky2/src/gadgets/split_nonnative.rs b/plonky2/src/gadgets/split_nonnative.rs new file mode 100644 index 00000000..d1f16b65 --- /dev/null +++ b/plonky2/src/gadgets/split_nonnative.rs @@ -0,0 +1,34 @@ +use itertools::Itertools; +use plonky2_field::extension_field::Extendable; +use plonky2_field::field_types::Field; + +use crate::gadgets::arithmetic_u32::U32Target; +use crate::gadgets::nonnative::NonNativeTarget; +use crate::hash::hash_types::RichField; +use crate::iop::target::Target; +use crate::plonk::circuit_builder::CircuitBuilder; + +impl, const D: usize> CircuitBuilder { + pub fn split_u32_to_4_bit_limbs(&mut self, val: U32Target) -> Vec { + let two_bit_limbs = self.split_le_base::<2>(val.0, 16); + let four = self.constant(F::from_canonical_usize(4)); + let combined_limbs = two_bit_limbs + .iter() + .tuples() + .map(|(&a, &b)| self.mul_add(b, four, a)) + .collect(); + + combined_limbs + } + + pub fn split_nonnative_to_4_bit_limbs( + &mut self, + val: &NonNativeTarget, + ) -> Vec { + val.value + .limbs + .iter() + .flat_map(|&l| self.split_u32_to_4_bit_limbs(l)) + .collect() + } +} diff --git a/plonky2/src/hash/merkle_proofs.rs b/plonky2/src/hash/merkle_proofs.rs index c3ebf406..c4188271 100644 --- a/plonky2/src/hash/merkle_proofs.rs +++ b/plonky2/src/hash/merkle_proofs.rs @@ -78,11 +78,9 @@ impl, const D: usize> CircuitBuilder { let index = self.le_sum(leaf_index_bits[proof.siblings.len()..].iter().copied()); for i in 0..4 { - self.random_access( - index, - state.elements[i], - merkle_cap.0.iter().map(|h| h.elements[i]).collect(), - ); + let result = + self.random_access(index, merkle_cap.0.iter().map(|h| h.elements[i]).collect()); + self.connect(result, state.elements[i]); } } @@ -110,11 +108,11 @@ impl, const D: usize> CircuitBuilder { } for i in 0..4 { - self.random_access( + let result = self.random_access( cap_index, - state.elements[i], merkle_cap.0.iter().map(|h| h.elements[i]).collect(), ); + self.connect(result, state.elements[i]); } } diff --git a/plonky2/src/plonk/circuit_builder.rs b/plonky2/src/plonk/circuit_builder.rs index bd216389..ff975659 100644 --- a/plonky2/src/plonk/circuit_builder.rs +++ b/plonky2/src/plonk/circuit_builder.rs @@ -846,3 +846,443 @@ impl, const D: usize> CircuitBuilder { } } } +<<<<<<< HEAD +======= + +/// Various gate types can contain multiple copies in a single Gate. This helper struct lets a +/// CircuitBuilder track such gates that are currently being "filled up." +pub struct BatchedGates, const D: usize> { + /// A map `(c0, c1) -> (g, i)` from constants `(c0,c1)` to an available arithmetic gate using + /// these constants with gate index `g` and already using `i` arithmetic operations. + pub(crate) free_arithmetic: HashMap<(F, F), (usize, usize)>, + pub(crate) free_base_arithmetic: HashMap<(F, F), (usize, usize)>, + + pub(crate) free_mul: HashMap, + + /// A map `b -> (g, i)` from `b` bits to an available random access gate of that size with gate + /// index `g` and already using `i` random accesses. + pub(crate) free_random_access: HashMap, + + /// `current_switch_gates[chunk_size - 1]` contains None if we have no switch gates with the value + /// chunk_size, and contains `(g, i, c)`, if the gate `g`, at index `i`, already contains `c` copies + /// of switches + pub(crate) current_switch_gates: Vec, usize, usize)>>, + + /// A map `n -> (g, i)` from `n` number of addends to an available `U32AddManyGate` of that size with gate + /// index `g` and already using `i` random accesses. + pub(crate) free_u32_add_many: HashMap, + + /// The `U32ArithmeticGate` currently being filled (so new u32 arithmetic operations will be added to this gate before creating a new one) + pub(crate) current_u32_arithmetic_gate: Option<(usize, usize)>, + /// The `U32SubtractionGate` currently being filled (so new u32 subtraction operations will be added to this gate before creating a new one) + pub(crate) current_u32_subtraction_gate: Option<(usize, usize)>, + + /// An available `ConstantGate` instance, if any. + pub(crate) free_constant: Option<(usize, usize)>, +} + +impl, const D: usize> BatchedGates { + pub fn new() -> Self { + Self { + free_arithmetic: HashMap::new(), + free_base_arithmetic: HashMap::new(), + free_mul: HashMap::new(), + free_random_access: HashMap::new(), + current_switch_gates: Vec::new(), + free_u32_add_many: HashMap::new(), + current_u32_arithmetic_gate: None, + current_u32_subtraction_gate: None, + free_constant: None, + } + } +} + +impl, const D: usize> CircuitBuilder { + /// Finds the last available arithmetic gate with the given constants or add one if there aren't any. + /// Returns `(g,i)` such that there is an arithmetic gate with the given constants at index + /// `g` and the gate's `i`-th operation is available. + pub(crate) fn find_base_arithmetic_gate(&mut self, const_0: F, const_1: F) -> (usize, usize) { + let (gate, i) = self + .batched_gates + .free_base_arithmetic + .get(&(const_0, const_1)) + .copied() + .unwrap_or_else(|| { + let gate = self.add_gate( + ArithmeticGate::new_from_config(&self.config), + vec![const_0, const_1], + ); + (gate, 0) + }); + + // Update `free_arithmetic` with new values. + if i < ArithmeticGate::num_ops(&self.config) - 1 { + self.batched_gates + .free_base_arithmetic + .insert((const_0, const_1), (gate, i + 1)); + } else { + self.batched_gates + .free_base_arithmetic + .remove(&(const_0, const_1)); + } + + (gate, i) + } + + /// Finds the last available arithmetic gate with the given constants or add one if there aren't any. + /// Returns `(g,i)` such that there is an arithmetic gate with the given constants at index + /// `g` and the gate's `i`-th operation is available. + pub(crate) fn find_arithmetic_gate(&mut self, const_0: F, const_1: F) -> (usize, usize) { + let (gate, i) = self + .batched_gates + .free_arithmetic + .get(&(const_0, const_1)) + .copied() + .unwrap_or_else(|| { + let gate = self.add_gate( + ArithmeticExtensionGate::new_from_config(&self.config), + vec![const_0, const_1], + ); + (gate, 0) + }); + + // Update `free_arithmetic` with new values. + if i < ArithmeticExtensionGate::::num_ops(&self.config) - 1 { + self.batched_gates + .free_arithmetic + .insert((const_0, const_1), (gate, i + 1)); + } else { + self.batched_gates + .free_arithmetic + .remove(&(const_0, const_1)); + } + + (gate, i) + } + + /// Finds the last available arithmetic gate with the given constants or add one if there aren't any. + /// Returns `(g,i)` such that there is an arithmetic gate with the given constants at index + /// `g` and the gate's `i`-th operation is available. + pub(crate) fn find_mul_gate(&mut self, const_0: F) -> (usize, usize) { + let (gate, i) = self + .batched_gates + .free_mul + .get(&const_0) + .copied() + .unwrap_or_else(|| { + let gate = self.add_gate( + MulExtensionGate::new_from_config(&self.config), + vec![const_0], + ); + (gate, 0) + }); + + // Update `free_arithmetic` with new values. + if i < MulExtensionGate::::num_ops(&self.config) - 1 { + self.batched_gates.free_mul.insert(const_0, (gate, i + 1)); + } else { + self.batched_gates.free_mul.remove(&const_0); + } + + (gate, i) + } + + /// Finds the last available random access gate with the given `bits` or adds one if there aren't any. + /// Returns `(g,i)` such that there is a random access gate for the given `bits` at index + /// `g` and the gate's `i`-th random access is available. + pub(crate) fn find_random_access_gate(&mut self, bits: usize) -> (usize, usize) { + let (gate, i) = self + .batched_gates + .free_random_access + .get(&bits) + .copied() + .unwrap_or_else(|| { + let gate = self.add_gate( + RandomAccessGate::new_from_config(&self.config, bits), + vec![], + ); + (gate, 0) + }); + + // Update `free_random_access` with new values. + if i + 1 < RandomAccessGate::::new_from_config(&self.config, bits).num_copies { + self.batched_gates + .free_random_access + .insert(bits, (gate, i + 1)); + } else { + self.batched_gates.free_random_access.remove(&bits); + } + + (gate, i) + } + + pub fn find_switch_gate(&mut self, chunk_size: usize) -> (SwitchGate, usize, usize) { + if self.batched_gates.current_switch_gates.len() < chunk_size { + self.batched_gates.current_switch_gates.extend(vec![ + None; + chunk_size + - self + .batched_gates + .current_switch_gates + .len() + ]); + } + + let (gate, gate_index, next_copy) = + match self.batched_gates.current_switch_gates[chunk_size - 1].clone() { + None => { + let gate = SwitchGate::::new_from_config(&self.config, chunk_size); + let gate_index = self.add_gate(gate.clone(), vec![]); + (gate, gate_index, 0) + } + Some((gate, idx, next_copy)) => (gate, idx, next_copy), + }; + + let num_copies = gate.num_copies; + + if next_copy == num_copies - 1 { + self.batched_gates.current_switch_gates[chunk_size - 1] = None; + } else { + self.batched_gates.current_switch_gates[chunk_size - 1] = + Some((gate.clone(), gate_index, next_copy + 1)); + } + + (gate, gate_index, next_copy) + } + + /// Finds the last available U32 add-many gate with the given `num_addends` or adds one if there aren't any. + /// Returns `(g,i)` such that there is a `U32AddManyGate` for the given `num_addends` at index + /// `g` and the gate's `i`-th copy is available. + pub(crate) fn find_u32_add_many_gate(&mut self, num_addends: usize) -> (usize, usize) { + let (gate, i) = self + .batched_gates + .free_u32_add_many + .get(&num_addends) + .copied() + .unwrap_or_else(|| { + let gate = self.add_gate( + U32AddManyGate::new_from_config(&self.config, num_addends), + vec![], + ); + (gate, 0) + }); + + // Update `free_u32_add_many` with new values. + if i + 1 < U32AddManyGate::::new_from_config(&self.config, num_addends).num_ops { + self.batched_gates + .free_u32_add_many + .insert(num_addends, (gate, i + 1)); + } else { + self.batched_gates.free_u32_add_many.remove(&num_addends); + } + + (gate, i) + } + + pub(crate) fn find_u32_arithmetic_gate(&mut self) -> (usize, usize) { + let (gate_index, copy) = match self.batched_gates.current_u32_arithmetic_gate { + None => { + let gate = U32ArithmeticGate::new_from_config(&self.config); + let gate_index = self.add_gate(gate, vec![]); + (gate_index, 0) + } + Some((gate_index, copy)) => (gate_index, copy), + }; + + if copy == U32ArithmeticGate::::num_ops(&self.config) - 1 { + self.batched_gates.current_u32_arithmetic_gate = None; + } else { + self.batched_gates.current_u32_arithmetic_gate = Some((gate_index, copy + 1)); + } + + (gate_index, copy) + } + + pub(crate) fn find_u32_subtraction_gate(&mut self) -> (usize, usize) { + let (gate_index, copy) = match self.batched_gates.current_u32_subtraction_gate { + None => { + let gate = U32SubtractionGate::new_from_config(&self.config); + let gate_index = self.add_gate(gate, vec![]); + (gate_index, 0) + } + Some((gate_index, copy)) => (gate_index, copy), + }; + + if copy == U32SubtractionGate::::num_ops(&self.config) - 1 { + self.batched_gates.current_u32_subtraction_gate = None; + } else { + self.batched_gates.current_u32_subtraction_gate = Some((gate_index, copy + 1)); + } + + (gate_index, copy) + } + + /// Returns the gate index and copy index of a free `ConstantGate` slot, potentially adding a + /// new `ConstantGate` if needed. + fn constant_gate_instance(&mut self) -> (usize, usize) { + if self.batched_gates.free_constant.is_none() { + let num_consts = self.config.constant_gate_size; + // We will fill this `ConstantGate` with zero constants initially. + // These will be overwritten by `constant` as the gate instances are filled. + let gate = self.add_gate(ConstantGate { num_consts }, vec![F::ZERO; num_consts]); + self.batched_gates.free_constant = Some((gate, 0)); + } + + let (gate, instance) = self.batched_gates.free_constant.unwrap(); + if instance + 1 < self.config.constant_gate_size { + self.batched_gates.free_constant = Some((gate, instance + 1)); + } else { + self.batched_gates.free_constant = None; + } + (gate, instance) + } + + /// Fill the remaining unused arithmetic operations with zeros, so that all + /// `ArithmeticGate` are run. + fn fill_base_arithmetic_gates(&mut self) { + let zero = self.zero(); + for ((c0, c1), (_gate, i)) in self.batched_gates.free_base_arithmetic.clone() { + for _ in i..ArithmeticGate::num_ops(&self.config) { + // If we directly wire in zero, an optimization will skip doing anything and return + // zero. So we pass in a virtual target and connect it to zero afterward. + let dummy = self.add_virtual_target(); + self.arithmetic(c0, c1, dummy, dummy, dummy); + self.connect(dummy, zero); + } + } + assert!(self.batched_gates.free_base_arithmetic.is_empty()); + } + + /// Fill the remaining unused arithmetic operations with zeros, so that all + /// `ArithmeticExtensionGenerator`s are run. + fn fill_arithmetic_gates(&mut self) { + let zero = self.zero_extension(); + for ((c0, c1), (_gate, i)) in self.batched_gates.free_arithmetic.clone() { + for _ in i..ArithmeticExtensionGate::::num_ops(&self.config) { + // If we directly wire in zero, an optimization will skip doing anything and return + // zero. So we pass in a virtual target and connect it to zero afterward. + let dummy = self.add_virtual_extension_target(); + self.arithmetic_extension(c0, c1, dummy, dummy, dummy); + self.connect_extension(dummy, zero); + } + } + assert!(self.batched_gates.free_arithmetic.is_empty()); + } + + /// Fill the remaining unused arithmetic operations with zeros, so that all + /// `ArithmeticExtensionGenerator`s are run. + fn fill_mul_gates(&mut self) { + let zero = self.zero_extension(); + for (c0, (_gate, i)) in self.batched_gates.free_mul.clone() { + for _ in i..MulExtensionGate::::num_ops(&self.config) { + // If we directly wire in zero, an optimization will skip doing anything and return + // zero. So we pass in a virtual target and connect it to zero afterward. + let dummy = self.add_virtual_extension_target(); + self.arithmetic_extension(c0, F::ZERO, dummy, dummy, zero); + self.connect_extension(dummy, zero); + } + } + assert!(self.batched_gates.free_mul.is_empty()); + } + + /// Fill the remaining unused random access operations with zeros, so that all + /// `RandomAccessGenerator`s are run. + fn fill_random_access_gates(&mut self) { + let zero = self.zero(); + for (bits, (_, i)) in self.batched_gates.free_random_access.clone() { + let max_copies = + RandomAccessGate::::new_from_config(&self.config, bits).num_copies; + for _ in i..max_copies { + let result = self.random_access(zero, vec![zero; 1 << bits]); + self.connect(result, zero); + } + } + } + + /// Fill the remaining unused switch gates with dummy values, so that all + /// `SwitchGenerator`s are run. + fn fill_switch_gates(&mut self) { + let zero = self.zero(); + + for chunk_size in 1..=self.batched_gates.current_switch_gates.len() { + if let Some((gate, gate_index, mut copy)) = + self.batched_gates.current_switch_gates[chunk_size - 1].clone() + { + while copy < gate.num_copies { + for element in 0..chunk_size { + let wire_first_input = + Target::wire(gate_index, gate.wire_first_input(copy, element)); + let wire_second_input = + Target::wire(gate_index, gate.wire_second_input(copy, element)); + let wire_switch_bool = + Target::wire(gate_index, gate.wire_switch_bool(copy)); + self.connect(zero, wire_first_input); + self.connect(zero, wire_second_input); + self.connect(zero, wire_switch_bool); + } + copy += 1; + } + } + } + } + + /// Fill the remaining unused u32 add-many operations with zeros, so that all + /// `U32AddManyGenerator`s are run. + fn fill_u32_add_many_gates(&mut self) { + let zero = self.zero_u32(); + for (num_addends, (_, i)) in self.batched_gates.free_u32_add_many.clone() { + let max_copies = + U32AddManyGate::::new_from_config(&self.config, num_addends).num_ops; + for _ in i..max_copies { + let gate = U32AddManyGate::::new_from_config(&self.config, num_addends); + let (gate_index, copy) = self.find_u32_add_many_gate(num_addends); + + for j in 0..num_addends { + self.connect( + Target::wire(gate_index, gate.wire_ith_op_jth_addend(copy, j)), + zero.0, + ); + } + self.connect(Target::wire(gate_index, gate.wire_ith_carry(copy)), zero.0); + } + } + } + + /// Fill the remaining unused U32 arithmetic operations with zeros, so that all + /// `U32ArithmeticGenerator`s are run. + fn fill_u32_arithmetic_gates(&mut self) { + let zero = self.zero_u32(); + if let Some((_gate_index, copy)) = self.batched_gates.current_u32_arithmetic_gate { + for _ in copy..U32ArithmeticGate::::num_ops(&self.config) { + let dummy = self.add_virtual_u32_target(); + self.mul_add_u32(dummy, dummy, dummy); + self.connect_u32(dummy, zero); + } + } + } + + /// Fill the remaining unused U32 subtraction operations with zeros, so that all + /// `U32SubtractionGenerator`s are run. + fn fill_u32_subtraction_gates(&mut self) { + let zero = self.zero_u32(); + if let Some((_gate_index, copy)) = self.batched_gates.current_u32_subtraction_gate { + for _i in copy..U32SubtractionGate::::num_ops(&self.config) { + let dummy = self.add_virtual_u32_target(); + self.sub_u32(dummy, dummy, dummy); + self.connect_u32(dummy, zero); + } + } + } + + fn fill_batched_gates(&mut self) { + self.fill_arithmetic_gates(); + self.fill_base_arithmetic_gates(); + self.fill_mul_gates(); + self.fill_random_access_gates(); + self.fill_switch_gates(); + self.fill_u32_add_many_gates(); + self.fill_u32_arithmetic_gates(); + self.fill_u32_subtraction_gates(); + } +} +>>>>>>> aa48021 (windowed multiplication in circuit) From 67b7193e82e19655187ba39bc2a360db0a128b90 Mon Sep 17 00:00:00 2001 From: Nicholas Ward Date: Thu, 10 Feb 2022 11:18:47 -0800 Subject: [PATCH 11/56] test for split nonnative, and fixes --- plonky2/src/gadgets/curve.rs | 34 +++++++++++++- plonky2/src/gadgets/split_nonnative.rs | 65 +++++++++++++++++++++++++- 2 files changed, 97 insertions(+), 2 deletions(-) diff --git a/plonky2/src/gadgets/curve.rs b/plonky2/src/gadgets/curve.rs index c2af0104..d56c7650 100644 --- a/plonky2/src/gadgets/curve.rs +++ b/plonky2/src/gadgets/curve.rs @@ -230,7 +230,7 @@ impl, const D: usize> CircuitBuilder { let windows = self.split_nonnative_to_4_bit_limbs(n); let m = C::ScalarField::BITS / WINDOW_SIZE; - for i in m..0 { + for i in (0..m).rev() { result = self.curve_double(&result); let window = windows[i]; @@ -411,6 +411,38 @@ mod tests { verify(proof, &data.verifier_only, &data.common) } + #[test] + fn test_curve_mul_windowed() -> Result<()> { + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + + let config = CircuitConfig::standard_ecc_config(); + + let pw = PartialWitness::new(); + let mut builder = CircuitBuilder::::new(config); + + let g = Secp256K1::GENERATOR_AFFINE; + let five = Secp256K1Scalar::from_canonical_usize(5); + let neg_five = five.neg(); + let neg_five_scalar = CurveScalar::(neg_five); + let neg_five_g = (neg_five_scalar * g.to_projective()).to_affine(); + let neg_five_g_expected = builder.constant_affine_point(neg_five_g); + builder.curve_assert_valid(&neg_five_g_expected); + + let g_target = builder.constant_affine_point(g); + let neg_five_target = builder.constant_nonnative(neg_five); + let neg_five_g_actual = builder.curve_scalar_mul_windowed(&g_target, &neg_five_target); + /*builder.curve_assert_valid(&neg_five_g_actual); + + builder.connect_affine_point(&neg_five_g_expected, &neg_five_g_actual);*/ + + let data = builder.build::(); + let proof = data.prove(pw).unwrap(); + + verify(proof, &data.verifier_only, &data.common) + } + #[test] #[ignore] fn test_curve_random() -> Result<()> { diff --git a/plonky2/src/gadgets/split_nonnative.rs b/plonky2/src/gadgets/split_nonnative.rs index d1f16b65..88c693c0 100644 --- a/plonky2/src/gadgets/split_nonnative.rs +++ b/plonky2/src/gadgets/split_nonnative.rs @@ -1,8 +1,11 @@ +use std::marker::PhantomData; + use itertools::Itertools; use plonky2_field::extension_field::Extendable; use plonky2_field::field_types::Field; use crate::gadgets::arithmetic_u32::U32Target; +use crate::gadgets::biguint::BigUintTarget; use crate::gadgets::nonnative::NonNativeTarget; use crate::hash::hash_types::RichField; use crate::iop::target::Target; @@ -10,7 +13,7 @@ use crate::plonk::circuit_builder::CircuitBuilder; impl, const D: usize> CircuitBuilder { pub fn split_u32_to_4_bit_limbs(&mut self, val: U32Target) -> Vec { - let two_bit_limbs = self.split_le_base::<2>(val.0, 16); + let two_bit_limbs = self.split_le_base::<4>(val.0, 16); let four = self.constant(F::from_canonical_usize(4)); let combined_limbs = two_bit_limbs .iter() @@ -31,4 +34,64 @@ impl, const D: usize> CircuitBuilder { .flat_map(|&l| self.split_u32_to_4_bit_limbs(l)) .collect() } + + // Note: assumes its inputs are 4-bit limbs, and does not range-check. + pub fn recombine_nonnative_4_bit_limbs( + &mut self, + limbs: Vec, + ) -> NonNativeTarget { + let base = self.constant_u32(1 << 4); + let u32_limbs = limbs + .chunks(8) + .map(|chunk| { + let mut combined_chunk = self.zero_u32(); + for i in (0..8).rev() { + let (low, _high) = self.mul_add_u32(combined_chunk, base, U32Target(chunk[i])); + combined_chunk = low; + } + combined_chunk + }) + .collect(); + + NonNativeTarget { + value: BigUintTarget { limbs: u32_limbs }, + _phantom: PhantomData, + } + } +} + +#[cfg(test)] +mod tests { + use anyhow::Result; + use plonky2_field::field_types::Field; + use plonky2_field::secp256k1_scalar::Secp256K1Scalar; + + use crate::gadgets::nonnative::NonNativeTarget; + use crate::iop::witness::PartialWitness; + use crate::plonk::circuit_builder::CircuitBuilder; + use crate::plonk::circuit_data::CircuitConfig; + use crate::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; + use crate::plonk::verifier::verify; + + #[test] + fn test_split_nonnative() -> Result<()> { + type FF = Secp256K1Scalar; + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + + let config = CircuitConfig::standard_ecc_config(); + let pw = PartialWitness::new(); + let mut builder = CircuitBuilder::::new(config); + + let x = FF::rand(); + let x_target = builder.constant_nonnative(x); + let split = builder.split_nonnative_to_4_bit_limbs(&x_target); + let combined: NonNativeTarget = builder.recombine_nonnative_4_bit_limbs(split); + builder.connect_nonnative(&x_target, &combined); + + let data = builder.build::(); + let proof = data.prove(pw).unwrap(); + verify(proof, &data.verifier_only, &data.common) + } } From 58492a0aceca413da2338b9e3228bfb7ab318721 Mon Sep 17 00:00:00 2001 From: Nicholas Ward Date: Thu, 10 Feb 2022 11:19:02 -0800 Subject: [PATCH 12/56] fmt --- plonky2/src/gadgets/split_nonnative.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/plonky2/src/gadgets/split_nonnative.rs b/plonky2/src/gadgets/split_nonnative.rs index 88c693c0..70661506 100644 --- a/plonky2/src/gadgets/split_nonnative.rs +++ b/plonky2/src/gadgets/split_nonnative.rs @@ -87,7 +87,8 @@ mod tests { let x = FF::rand(); let x_target = builder.constant_nonnative(x); let split = builder.split_nonnative_to_4_bit_limbs(&x_target); - let combined: NonNativeTarget = builder.recombine_nonnative_4_bit_limbs(split); + let combined: NonNativeTarget = + builder.recombine_nonnative_4_bit_limbs(split); builder.connect_nonnative(&x_target, &combined); let data = builder.build::(); From 5603816f3baafc97f2953e43284f3a8a8ea61590 Mon Sep 17 00:00:00 2001 From: Nicholas Ward Date: Thu, 10 Feb 2022 12:14:30 -0800 Subject: [PATCH 13/56] fix --- plonky2/src/gadgets/curve.rs | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/plonky2/src/gadgets/curve.rs b/plonky2/src/gadgets/curve.rs index d56c7650..21981f9b 100644 --- a/plonky2/src/gadgets/curve.rs +++ b/plonky2/src/gadgets/curve.rs @@ -165,6 +165,7 @@ impl, const D: usize> CircuitBuilder { result } + // TODO: fix if p is the generator pub fn precompute_window( &mut self, p: &AffinePointTarget, @@ -172,13 +173,15 @@ impl, const D: usize> CircuitBuilder { let mut multiples = Vec::new(); multiples.push(self.constant_affine_point(C::GENERATOR_AFFINE)); let mut cur = p.clone(); - for _pow in 1..WINDOW_SIZE { + for _pow in 0..WINDOW_SIZE { for existing in multiples.clone() { multiples.push(self.curve_add(&cur, &existing)); } cur = self.curve_double(&cur); } + println!("SIZE OF WINDOW: {}", multiples.len()); + multiples } @@ -422,7 +425,7 @@ mod tests { let pw = PartialWitness::new(); let mut builder = CircuitBuilder::::new(config); - let g = Secp256K1::GENERATOR_AFFINE; + let g = (CurveScalar(Secp256K1Scalar::rand()) * Secp256K1::GENERATOR_PROJECTIVE).to_affine(); let five = Secp256K1Scalar::from_canonical_usize(5); let neg_five = five.neg(); let neg_five_scalar = CurveScalar::(neg_five); @@ -433,9 +436,9 @@ mod tests { let g_target = builder.constant_affine_point(g); let neg_five_target = builder.constant_nonnative(neg_five); let neg_five_g_actual = builder.curve_scalar_mul_windowed(&g_target, &neg_five_target); - /*builder.curve_assert_valid(&neg_five_g_actual); + builder.curve_assert_valid(&neg_five_g_actual); - builder.connect_affine_point(&neg_five_g_expected, &neg_five_g_actual);*/ + builder.connect_affine_point(&neg_five_g_expected, &neg_five_g_actual); let data = builder.build::(); let proof = data.prove(pw).unwrap(); From 294a738dc9232cf8a5f85023bde0307f7c650fb6 Mon Sep 17 00:00:00 2001 From: Nicholas Ward Date: Fri, 11 Feb 2022 09:01:02 -0800 Subject: [PATCH 14/56] moved to new file, and curve random access test --- plonky2/src/gadgets/curve.rs | 120 +-------------- plonky2/src/gadgets/curve_windowed_mul.rs | 175 ++++++++++++++++++++++ plonky2/src/gadgets/mod.rs | 1 + 3 files changed, 177 insertions(+), 119 deletions(-) create mode 100644 plonky2/src/gadgets/curve_windowed_mul.rs diff --git a/plonky2/src/gadgets/curve.rs b/plonky2/src/gadgets/curve.rs index 21981f9b..908ed66e 100644 --- a/plonky2/src/gadgets/curve.rs +++ b/plonky2/src/gadgets/curve.rs @@ -1,18 +1,11 @@ -use std::marker::PhantomData; - use plonky2_field::extension_field::Extendable; use plonky2_field::field_types::Field; use crate::curve::curve_types::{AffinePoint, Curve, CurveScalar}; -use crate::gadgets::arithmetic_u32::U32Target; -use crate::gadgets::biguint::BigUintTarget; use crate::gadgets::nonnative::NonNativeTarget; use crate::hash::hash_types::RichField; -use crate::iop::target::Target; use crate::plonk::circuit_builder::CircuitBuilder; -const WINDOW_SIZE: usize = 4; - /// A Target representing an affine point on the curve `C`. We use incomplete arithmetic for efficiency, /// so we assume these points are not zero. #[derive(Clone, Debug)] @@ -164,85 +157,6 @@ impl, const D: usize> CircuitBuilder { result } - - // TODO: fix if p is the generator - pub fn precompute_window( - &mut self, - p: &AffinePointTarget, - ) -> Vec> { - let mut multiples = Vec::new(); - multiples.push(self.constant_affine_point(C::GENERATOR_AFFINE)); - let mut cur = p.clone(); - for _pow in 0..WINDOW_SIZE { - for existing in multiples.clone() { - multiples.push(self.curve_add(&cur, &existing)); - } - cur = self.curve_double(&cur); - } - - println!("SIZE OF WINDOW: {}", multiples.len()); - - multiples - } - - pub fn random_access_curve_points( - &mut self, - access_index: Target, - v: Vec>, - ) -> AffinePointTarget { - let num_limbs = v[0].x.value.num_limbs(); - let x_limbs: Vec> = (0..num_limbs) - .map(|i| v.iter().map(|p| p.x.value.limbs[i].0).collect()) - .collect(); - let y_limbs: Vec> = (0..num_limbs) - .map(|i| v.iter().map(|p| p.y.value.limbs[i].0).collect()) - .collect(); - - let selected_x_limbs: Vec<_> = x_limbs - .iter() - .map(|limbs| U32Target(self.random_access(access_index, limbs.clone()))) - .collect(); - let selected_y_limbs: Vec<_> = y_limbs - .iter() - .map(|limbs| U32Target(self.random_access(access_index, limbs.clone()))) - .collect(); - - let x = NonNativeTarget { - value: BigUintTarget { - limbs: selected_x_limbs, - }, - _phantom: PhantomData, - }; - let y = NonNativeTarget { - value: BigUintTarget { - limbs: selected_y_limbs, - }, - _phantom: PhantomData, - }; - AffinePointTarget { x, y } - } - - pub fn curve_scalar_mul_windowed( - &mut self, - p: &AffinePointTarget, - n: &NonNativeTarget, - ) -> AffinePointTarget { - let mut result = self.constant_affine_point(C::GENERATOR_AFFINE); - - let precomputation = self.precompute_window(p); - - let windows = self.split_nonnative_to_4_bit_limbs(n); - let m = C::ScalarField::BITS / WINDOW_SIZE; - for i in (0..m).rev() { - result = self.curve_double(&result); - let window = windows[i]; - - let to_add = self.random_access_curve_points(window, precomputation.clone()); - result = self.curve_add(&result, &to_add); - } - - result - } } #[cfg(test)] @@ -393,7 +307,7 @@ mod tests { let pw = PartialWitness::new(); let mut builder = CircuitBuilder::::new(config); - let g = Secp256K1::GENERATOR_AFFINE; + let g = Secp256K1::GENERATOR_PROJECTIVE.to_affine(); let five = Secp256K1Scalar::from_canonical_usize(5); let neg_five = five.neg(); let neg_five_scalar = CurveScalar::(neg_five); @@ -414,38 +328,6 @@ mod tests { verify(proof, &data.verifier_only, &data.common) } - #[test] - fn test_curve_mul_windowed() -> Result<()> { - const D: usize = 2; - type C = PoseidonGoldilocksConfig; - type F = >::F; - - let config = CircuitConfig::standard_ecc_config(); - - let pw = PartialWitness::new(); - let mut builder = CircuitBuilder::::new(config); - - let g = (CurveScalar(Secp256K1Scalar::rand()) * Secp256K1::GENERATOR_PROJECTIVE).to_affine(); - let five = Secp256K1Scalar::from_canonical_usize(5); - let neg_five = five.neg(); - let neg_five_scalar = CurveScalar::(neg_five); - let neg_five_g = (neg_five_scalar * g.to_projective()).to_affine(); - let neg_five_g_expected = builder.constant_affine_point(neg_five_g); - builder.curve_assert_valid(&neg_five_g_expected); - - let g_target = builder.constant_affine_point(g); - let neg_five_target = builder.constant_nonnative(neg_five); - let neg_five_g_actual = builder.curve_scalar_mul_windowed(&g_target, &neg_five_target); - builder.curve_assert_valid(&neg_five_g_actual); - - builder.connect_affine_point(&neg_five_g_expected, &neg_five_g_actual); - - let data = builder.build::(); - let proof = data.prove(pw).unwrap(); - - verify(proof, &data.verifier_only, &data.common) - } - #[test] #[ignore] fn test_curve_random() -> Result<()> { diff --git a/plonky2/src/gadgets/curve_windowed_mul.rs b/plonky2/src/gadgets/curve_windowed_mul.rs new file mode 100644 index 00000000..77bc6f3f --- /dev/null +++ b/plonky2/src/gadgets/curve_windowed_mul.rs @@ -0,0 +1,175 @@ +use std::marker::PhantomData; + +use plonky2_field::extension_field::Extendable; +use plonky2_field::field_types::Field; + +use crate::curve::curve_types::Curve; +use crate::gadgets::arithmetic_u32::U32Target; +use crate::gadgets::biguint::BigUintTarget; +use crate::gadgets::curve::AffinePointTarget; +use crate::gadgets::nonnative::NonNativeTarget; +use crate::hash::hash_types::RichField; +use crate::iop::target::Target; +use crate::plonk::circuit_builder::CircuitBuilder; + +const WINDOW_SIZE: usize = 4; + +impl, const D: usize> CircuitBuilder { + // TODO: fix if p is the generator + pub fn precompute_window( + &mut self, + p: &AffinePointTarget, + ) -> Vec> { + let mut multiples = Vec::new(); + multiples.push(self.constant_affine_point(C::GENERATOR_AFFINE)); + let mut cur = p.clone(); + for _pow in 0..WINDOW_SIZE { + for existing in multiples.clone() { + multiples.push(self.curve_add(&cur, &existing)); + } + cur = self.curve_double(&cur); + } + + multiples + } + + pub fn random_access_curve_points( + &mut self, + access_index: Target, + v: Vec>, + ) -> AffinePointTarget { + let num_limbs = v[0].x.value.num_limbs(); + let x_limbs: Vec> = (0..num_limbs) + .map(|i| v.iter().map(|p| p.x.value.limbs[i].0).collect()) + .collect(); + let y_limbs: Vec> = (0..num_limbs) + .map(|i| v.iter().map(|p| p.y.value.limbs[i].0).collect()) + .collect(); + + let selected_x_limbs: Vec<_> = x_limbs + .iter() + .map(|limbs| U32Target(self.random_access(access_index, limbs.clone()))) + .collect(); + let selected_y_limbs: Vec<_> = y_limbs + .iter() + .map(|limbs| U32Target(self.random_access(access_index, limbs.clone()))) + .collect(); + + let x = NonNativeTarget { + value: BigUintTarget { + limbs: selected_x_limbs, + }, + _phantom: PhantomData, + }; + let y = NonNativeTarget { + value: BigUintTarget { + limbs: selected_y_limbs, + }, + _phantom: PhantomData, + }; + AffinePointTarget { x, y } + } + + pub fn curve_scalar_mul_windowed( + &mut self, + p: &AffinePointTarget, + n: &NonNativeTarget, + ) -> AffinePointTarget { + let mut result = self.constant_affine_point(C::GENERATOR_AFFINE); + + let precomputation = self.precompute_window(p); + + let windows = self.split_nonnative_to_4_bit_limbs(n); + let m = C::ScalarField::BITS / WINDOW_SIZE; + for i in (0..m).rev() { + result = self.curve_double(&result); + let window = windows[i]; + + let to_add = self.random_access_curve_points(window, precomputation.clone()); + result = self.curve_add(&result, &to_add); + } + + result + } +} + +#[cfg(test)] +mod tests { + use std::ops::Neg; + + use anyhow::Result; + use plonky2_field::field_types::Field; + use plonky2_field::secp256k1_scalar::Secp256K1Scalar; + use rand::Rng; + + use crate::curve::curve_types::{Curve, CurveScalar}; + use crate::curve::secp256k1::Secp256K1; + use crate::iop::witness::PartialWitness; + use crate::plonk::circuit_builder::CircuitBuilder; + use crate::plonk::circuit_data::CircuitConfig; + use crate::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; + use crate::plonk::verifier::verify; + + #[test] + fn test_random_access_curve_points() -> Result<()> { + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + + let config = CircuitConfig::standard_ecc_config(); + + let pw = PartialWitness::new(); + let mut builder = CircuitBuilder::::new(config); + + let num_points = 8; + let points: Vec<_> = (0..num_points).map(|_| { + let g = (CurveScalar(Secp256K1Scalar::rand()) * Secp256K1::GENERATOR_PROJECTIVE).to_affine(); + builder.constant_affine_point(g) + }).collect(); + + let mut rng = rand::thread_rng(); + let mut access_index = rng.gen::() % 8; + + let access_index_target = builder.constant(F::from_canonical_usize(access_index)); + let selected = builder.random_access_curve_points(access_index_target, points.clone()); + let expected = points[access_index].clone(); + builder.connect_affine_point(&selected, &expected); + + let data = builder.build::(); + let proof = data.prove(pw).unwrap(); + + verify(proof, &data.verifier_only, &data.common) + } + + #[test] + fn test_curve_mul_windowed() -> Result<()> { + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + + let config = CircuitConfig::standard_ecc_config(); + + let pw = PartialWitness::new(); + let mut builder = CircuitBuilder::::new(config); + + let g = (CurveScalar(Secp256K1Scalar::rand()) * Secp256K1::GENERATOR_PROJECTIVE).to_affine(); + let five = Secp256K1Scalar::from_canonical_usize(5); + let neg_five = five.neg(); + let neg_five_scalar = CurveScalar::(neg_five); + let neg_five_g = (neg_five_scalar * g.to_projective()).to_affine(); + let neg_five_g_expected = builder.constant_affine_point(neg_five_g); + builder.curve_assert_valid(&neg_five_g_expected); + + let g_target = builder.constant_affine_point(g); + let neg_five_target = builder.constant_nonnative(neg_five); + let neg_five_g_actual = builder.curve_scalar_mul_windowed(&g_target, &neg_five_target); + builder.curve_assert_valid(&neg_five_g_actual); + + builder.connect_affine_point(&neg_five_g_expected, &neg_five_g_actual); + + let data = builder.build::(); + let proof = data.prove(pw).unwrap(); + + verify(proof, &data.verifier_only, &data.common) + } +} diff --git a/plonky2/src/gadgets/mod.rs b/plonky2/src/gadgets/mod.rs index 95e46b42..d9c93db3 100644 --- a/plonky2/src/gadgets/mod.rs +++ b/plonky2/src/gadgets/mod.rs @@ -3,6 +3,7 @@ pub mod arithmetic_extension; pub mod arithmetic_u32; pub mod biguint; pub mod curve; +pub mod curve_windowed_mul; // pub mod curve_msm; pub mod ecdsa; pub mod glv; From 64a09616e273c79a0359bc9d668135e9277faf9d Mon Sep 17 00:00:00 2001 From: Nicholas Ward Date: Fri, 11 Feb 2022 09:01:08 -0800 Subject: [PATCH 15/56] fmt --- plonky2/src/gadgets/curve_windowed_mul.rs | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/plonky2/src/gadgets/curve_windowed_mul.rs b/plonky2/src/gadgets/curve_windowed_mul.rs index 77bc6f3f..bbb0a991 100644 --- a/plonky2/src/gadgets/curve_windowed_mul.rs +++ b/plonky2/src/gadgets/curve_windowed_mul.rs @@ -122,10 +122,13 @@ mod tests { let mut builder = CircuitBuilder::::new(config); let num_points = 8; - let points: Vec<_> = (0..num_points).map(|_| { - let g = (CurveScalar(Secp256K1Scalar::rand()) * Secp256K1::GENERATOR_PROJECTIVE).to_affine(); - builder.constant_affine_point(g) - }).collect(); + let points: Vec<_> = (0..num_points) + .map(|_| { + let g = (CurveScalar(Secp256K1Scalar::rand()) * Secp256K1::GENERATOR_PROJECTIVE) + .to_affine(); + builder.constant_affine_point(g) + }) + .collect(); let mut rng = rand::thread_rng(); let mut access_index = rng.gen::() % 8; @@ -152,7 +155,8 @@ mod tests { let pw = PartialWitness::new(); let mut builder = CircuitBuilder::::new(config); - let g = (CurveScalar(Secp256K1Scalar::rand()) * Secp256K1::GENERATOR_PROJECTIVE).to_affine(); + let g = + (CurveScalar(Secp256K1Scalar::rand()) * Secp256K1::GENERATOR_PROJECTIVE).to_affine(); let five = Secp256K1Scalar::from_canonical_usize(5); let neg_five = five.neg(); let neg_five_scalar = CurveScalar::(neg_five); From 23cfe910799c88ab9ac2d4a02d977cd7944d52ac Mon Sep 17 00:00:00 2001 From: Nicholas Ward Date: Fri, 11 Feb 2022 09:02:21 -0800 Subject: [PATCH 16/56] fix --- plonky2/src/gadgets/curve_windowed_mul.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/plonky2/src/gadgets/curve_windowed_mul.rs b/plonky2/src/gadgets/curve_windowed_mul.rs index bbb0a991..0479e142 100644 --- a/plonky2/src/gadgets/curve_windowed_mul.rs +++ b/plonky2/src/gadgets/curve_windowed_mul.rs @@ -131,7 +131,7 @@ mod tests { .collect(); let mut rng = rand::thread_rng(); - let mut access_index = rng.gen::() % 8; + let access_index = rng.gen::() % 8; let access_index_target = builder.constant(F::from_canonical_usize(access_index)); let selected = builder.random_access_curve_points(access_index_target, points.clone()); From 8bab62b83d73ae8e9f8d26533cbfb9c23118422b Mon Sep 17 00:00:00 2001 From: Nicholas Ward Date: Fri, 11 Feb 2022 09:06:54 -0800 Subject: [PATCH 17/56] fix --- plonky2/src/gadgets/curve_windowed_mul.rs | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/plonky2/src/gadgets/curve_windowed_mul.rs b/plonky2/src/gadgets/curve_windowed_mul.rs index 0479e142..9abc4ea3 100644 --- a/plonky2/src/gadgets/curve_windowed_mul.rs +++ b/plonky2/src/gadgets/curve_windowed_mul.rs @@ -20,8 +20,7 @@ impl, const D: usize> CircuitBuilder { &mut self, p: &AffinePointTarget, ) -> Vec> { - let mut multiples = Vec::new(); - multiples.push(self.constant_affine_point(C::GENERATOR_AFFINE)); + let mut multiples = vec![self.constant_affine_point(C::GENERATOR_AFFINE)]; let mut cur = p.clone(); for _pow in 0..WINDOW_SIZE { for existing in multiples.clone() { From 978e2ee974bc91e6fdbdc8d6639f1144c259b15e Mon Sep 17 00:00:00 2001 From: Nicholas Ward Date: Fri, 11 Feb 2022 11:41:06 -0800 Subject: [PATCH 18/56] conditional add (doesn't work yet) --- plonky2/src/gadgets/curve.rs | 50 +++++++++++++++++++++++ plonky2/src/gadgets/curve_windowed_mul.rs | 4 +- 2 files changed, 52 insertions(+), 2 deletions(-) diff --git a/plonky2/src/gadgets/curve.rs b/plonky2/src/gadgets/curve.rs index 908ed66e..83309196 100644 --- a/plonky2/src/gadgets/curve.rs +++ b/plonky2/src/gadgets/curve.rs @@ -4,6 +4,7 @@ use plonky2_field::field_types::Field; use crate::curve::curve_types::{AffinePoint, Curve, CurveScalar}; use crate::gadgets::nonnative::NonNativeTarget; use crate::hash::hash_types::RichField; +use crate::iop::target::BoolTarget; use crate::plonk::circuit_builder::CircuitBuilder; /// A Target representing an affine point on the curve `C`. We use incomplete arithmetic for efficiency, @@ -117,6 +118,23 @@ impl, const D: usize> CircuitBuilder { AffinePointTarget { x: x3, y: y3 } } + pub fn curve_conditional_add( + &mut self, + p1: &AffinePointTarget, + p2: &AffinePointTarget, + b: BoolTarget, + ) -> AffinePointTarget { + let to_add_x = self.mul_nonnative_by_bool(&p2.x, b); + let to_add_y = self.mul_nonnative_by_bool(&p2.y, b); + let sum_x = self.add_nonnative(&p1.x, &to_add_x); + let sum_y = self.add_nonnative(&p1.y, &to_add_y); + + AffinePointTarget { + x: sum_x, + y: sum_y, + } + } + pub fn curve_scalar_mul( &mut self, p: &AffinePointTarget, @@ -295,6 +313,38 @@ mod tests { verify(proof, &data.verifier_only, &data.common) } + #[test] + fn test_curve_conditional_add() -> Result<()> { + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + + let config = CircuitConfig::standard_ecc_config(); + + let pw = PartialWitness::new(); + let mut builder = CircuitBuilder::::new(config); + + let g = Secp256K1::GENERATOR_AFFINE; + let double_g = g.double(); + let g_plus_2g = (g + double_g).to_affine(); + let g_plus_2g_expected = builder.constant_affine_point(g_plus_2g); + + let g_expected = builder.constant_affine_point(g); + let double_g_target = builder.curve_double(&g_expected); + let t = builder._true(); + let f = builder._false(); + let g_plus_2g_actual = builder.curve_conditional_add(&g_expected, &double_g_target, t); + let g_actual = builder.curve_conditional_add(&g_expected, &double_g_target, f); + + builder.connect_affine_point(&g_plus_2g_expected, &g_plus_2g_actual); + builder.connect_affine_point(&g_expected, &g_actual); + + let data = builder.build::(); + let proof = data.prove(pw).unwrap(); + + verify(proof, &data.verifier_only, &data.common) + } + #[test] #[ignore] fn test_curve_mul() -> Result<()> { diff --git a/plonky2/src/gadgets/curve_windowed_mul.rs b/plonky2/src/gadgets/curve_windowed_mul.rs index 9abc4ea3..f9915043 100644 --- a/plonky2/src/gadgets/curve_windowed_mul.rs +++ b/plonky2/src/gadgets/curve_windowed_mul.rs @@ -120,7 +120,7 @@ mod tests { let pw = PartialWitness::new(); let mut builder = CircuitBuilder::::new(config); - let num_points = 8; + let num_points = 16; let points: Vec<_> = (0..num_points) .map(|_| { let g = (CurveScalar(Secp256K1Scalar::rand()) * Secp256K1::GENERATOR_PROJECTIVE) @@ -130,7 +130,7 @@ mod tests { .collect(); let mut rng = rand::thread_rng(); - let access_index = rng.gen::() % 8; + let access_index = rng.gen::() % num_points; let access_index_target = builder.constant(F::from_canonical_usize(access_index)); let selected = builder.random_access_curve_points(access_index_target, points.clone()); From 134a04220ddfeafd574a03b1f5ce90ef26991d52 Mon Sep 17 00:00:00 2001 From: Nicholas Ward Date: Fri, 11 Feb 2022 12:34:09 -0800 Subject: [PATCH 19/56] is_equal function --- plonky2/src/gadgets/arithmetic.rs | 40 +++++++++++++++++++++++++++++++ plonky2/src/gadgets/curve.rs | 17 ++++++++----- 2 files changed, 51 insertions(+), 6 deletions(-) diff --git a/plonky2/src/gadgets/arithmetic.rs b/plonky2/src/gadgets/arithmetic.rs index 11fc57bf..3c43d97a 100644 --- a/plonky2/src/gadgets/arithmetic.rs +++ b/plonky2/src/gadgets/arithmetic.rs @@ -6,7 +6,9 @@ use plonky2_field::field_types::Field64; use crate::gates::arithmetic_base::ArithmeticGate; use crate::gates::exponentiation::ExponentiationGate; use crate::hash::hash_types::RichField; +use crate::iop::generator::{SimpleGenerator, GeneratedValues}; use crate::iop::target::{BoolTarget, Target}; +use crate::iop::witness::{PartitionWitness, Witness}; use crate::plonk::circuit_builder::CircuitBuilder; impl, const D: usize> CircuitBuilder { @@ -323,6 +325,44 @@ impl, const D: usize> CircuitBuilder { let res = self.sub(one, b.target); BoolTarget::new_unsafe(res) } + + pub fn is_equal(&mut self, x: Target, y: Target) -> BoolTarget { + let b = self.add_virtual_bool_target(); + self.add_simple_generator(EqualityGenerator { + x, + y, + b, + }); + + let diff = self.sub(x, y); + let result = self.mul(b.target, diff); + + let zero = self.zero(); + self.connect(zero, result); + + b + } +} + +#[derive(Debug)] +struct EqualityGenerator { + x: Target, + y: Target, + b: BoolTarget, +} + +impl SimpleGenerator for EqualityGenerator +{ + fn dependencies(&self) -> Vec { + vec![self.x, self.y] + } + + fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { + let x = witness.get_target(self.x); + let y = witness.get_target(self.y); + + out_buffer.set_bool_target(self.b, x == y); + } } /// Represents a base arithmetic operation in the circuit. Used to memoize results. diff --git a/plonky2/src/gadgets/curve.rs b/plonky2/src/gadgets/curve.rs index 83309196..d2a298a8 100644 --- a/plonky2/src/gadgets/curve.rs +++ b/plonky2/src/gadgets/curve.rs @@ -124,14 +124,19 @@ impl, const D: usize> CircuitBuilder { p2: &AffinePointTarget, b: BoolTarget, ) -> AffinePointTarget { - let to_add_x = self.mul_nonnative_by_bool(&p2.x, b); - let to_add_y = self.mul_nonnative_by_bool(&p2.y, b); - let sum_x = self.add_nonnative(&p1.x, &to_add_x); - let sum_y = self.add_nonnative(&p1.y, &to_add_y); + let not_b = self.not(b); + let sum = self.curve_add(p1, p2); + let x_if_true = self.mul_nonnative_by_bool(&sum.x, b); + let y_if_true = self.mul_nonnative_by_bool(&sum.y, b); + let x_if_false = self.mul_nonnative_by_bool(&p1.x, not_b); + let y_if_false = self.mul_nonnative_by_bool(&p1.y, not_b); + + let x = self.add_nonnative(&x_if_true, &x_if_false); + let y = self.add_nonnative(&y_if_true, &y_if_false); AffinePointTarget { - x: sum_x, - y: sum_y, + x, + y, } } From 84edb55b63db2e00be50469e5c681c8a239097b5 Mon Sep 17 00:00:00 2001 From: Nicholas Ward Date: Fri, 11 Feb 2022 12:34:21 -0800 Subject: [PATCH 20/56] fmt --- plonky2/src/gadgets/arithmetic.rs | 11 +++-------- plonky2/src/gadgets/curve.rs | 7 ++----- 2 files changed, 5 insertions(+), 13 deletions(-) diff --git a/plonky2/src/gadgets/arithmetic.rs b/plonky2/src/gadgets/arithmetic.rs index 3c43d97a..7f40cdef 100644 --- a/plonky2/src/gadgets/arithmetic.rs +++ b/plonky2/src/gadgets/arithmetic.rs @@ -6,7 +6,7 @@ use plonky2_field::field_types::Field64; use crate::gates::arithmetic_base::ArithmeticGate; use crate::gates::exponentiation::ExponentiationGate; use crate::hash::hash_types::RichField; -use crate::iop::generator::{SimpleGenerator, GeneratedValues}; +use crate::iop::generator::{GeneratedValues, SimpleGenerator}; use crate::iop::target::{BoolTarget, Target}; use crate::iop::witness::{PartitionWitness, Witness}; use crate::plonk::circuit_builder::CircuitBuilder; @@ -328,11 +328,7 @@ impl, const D: usize> CircuitBuilder { pub fn is_equal(&mut self, x: Target, y: Target) -> BoolTarget { let b = self.add_virtual_bool_target(); - self.add_simple_generator(EqualityGenerator { - x, - y, - b, - }); + self.add_simple_generator(EqualityGenerator { x, y, b }); let diff = self.sub(x, y); let result = self.mul(b.target, diff); @@ -351,8 +347,7 @@ struct EqualityGenerator { b: BoolTarget, } -impl SimpleGenerator for EqualityGenerator -{ +impl SimpleGenerator for EqualityGenerator { fn dependencies(&self) -> Vec { vec![self.x, self.y] } diff --git a/plonky2/src/gadgets/curve.rs b/plonky2/src/gadgets/curve.rs index d2a298a8..0a0650e9 100644 --- a/plonky2/src/gadgets/curve.rs +++ b/plonky2/src/gadgets/curve.rs @@ -133,11 +133,8 @@ impl, const D: usize> CircuitBuilder { let x = self.add_nonnative(&x_if_true, &x_if_false); let y = self.add_nonnative(&y_if_true, &y_if_false); - - AffinePointTarget { - x, - y, - } + + AffinePointTarget { x, y } } pub fn curve_scalar_mul( From 3787f3be22c759a7f5ef2f23fef444912118390e Mon Sep 17 00:00:00 2001 From: Nicholas Ward Date: Fri, 11 Feb 2022 12:35:36 -0800 Subject: [PATCH 21/56] conditional add --- plonky2/src/gadgets/curve_windowed_mul.rs | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/plonky2/src/gadgets/curve_windowed_mul.rs b/plonky2/src/gadgets/curve_windowed_mul.rs index f9915043..57d8c558 100644 --- a/plonky2/src/gadgets/curve_windowed_mul.rs +++ b/plonky2/src/gadgets/curve_windowed_mul.rs @@ -77,6 +77,7 @@ impl, const D: usize> CircuitBuilder { let mut result = self.constant_affine_point(C::GENERATOR_AFFINE); let precomputation = self.precompute_window(p); + let zero = self.zero(); let windows = self.split_nonnative_to_4_bit_limbs(n); let m = C::ScalarField::BITS / WINDOW_SIZE; @@ -85,7 +86,9 @@ impl, const D: usize> CircuitBuilder { let window = windows[i]; let to_add = self.random_access_curve_points(window, precomputation.clone()); - result = self.curve_add(&result, &to_add); + let is_zero = self.is_equal(window, zero); + let should_add = self.not(is_zero); + result = self.curve_conditional_add(&result, &to_add, should_add); } result From ad1aa4ae10d775fddf7c5a9b1345b417103f8417 Mon Sep 17 00:00:00 2001 From: Nicholas Ward Date: Fri, 11 Feb 2022 12:54:41 -0800 Subject: [PATCH 22/56] fixed is_equal --- plonky2/src/gadgets/arithmetic.rs | 31 +++++++++++++++++++++++-------- 1 file changed, 23 insertions(+), 8 deletions(-) diff --git a/plonky2/src/gadgets/arithmetic.rs b/plonky2/src/gadgets/arithmetic.rs index 7f40cdef..3fee2ecd 100644 --- a/plonky2/src/gadgets/arithmetic.rs +++ b/plonky2/src/gadgets/arithmetic.rs @@ -327,16 +327,23 @@ impl, const D: usize> CircuitBuilder { } pub fn is_equal(&mut self, x: Target, y: Target) -> BoolTarget { - let b = self.add_virtual_bool_target(); - self.add_simple_generator(EqualityGenerator { x, y, b }); + let zero = self.zero(); + + let equal = self.add_virtual_bool_target(); + let not_equal = self.not(equal); + let inv = self.add_virtual_target(); + self.add_simple_generator(EqualityGenerator { x, y, equal, inv }); let diff = self.sub(x, y); - let result = self.mul(b.target, diff); + let not_equal_check = self.mul(equal.target, diff); - let zero = self.zero(); - self.connect(zero, result); + let diff_normalized = self.mul(diff, inv); + let equal_check = self.sub(diff_normalized, not_equal.target); - b + self.connect(not_equal_check, zero); + self.connect(equal_check, zero); + + equal } } @@ -344,7 +351,8 @@ impl, const D: usize> CircuitBuilder { struct EqualityGenerator { x: Target, y: Target, - b: BoolTarget, + equal: BoolTarget, + inv: Target, } impl SimpleGenerator for EqualityGenerator { @@ -356,7 +364,14 @@ impl SimpleGenerator for EqualityGenerator { let x = witness.get_target(self.x); let y = witness.get_target(self.y); - out_buffer.set_bool_target(self.b, x == y); + let inv = if x != y { + (x - y).inverse() + } else { + F::ZERO + }; + + out_buffer.set_bool_target(self.equal, x == y); + out_buffer.set_target(self.inv, inv); } } From f67e12ee64afa4d12a985c6e41dc65cce98960ab Mon Sep 17 00:00:00 2001 From: Nicholas Ward Date: Fri, 11 Feb 2022 12:54:47 -0800 Subject: [PATCH 23/56] fmt --- plonky2/src/gadgets/arithmetic.rs | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/plonky2/src/gadgets/arithmetic.rs b/plonky2/src/gadgets/arithmetic.rs index 3fee2ecd..70cd24ad 100644 --- a/plonky2/src/gadgets/arithmetic.rs +++ b/plonky2/src/gadgets/arithmetic.rs @@ -364,11 +364,7 @@ impl SimpleGenerator for EqualityGenerator { let x = witness.get_target(self.x); let y = witness.get_target(self.y); - let inv = if x != y { - (x - y).inverse() - } else { - F::ZERO - }; + let inv = if x != y { (x - y).inverse() } else { F::ZERO }; out_buffer.set_bool_target(self.equal, x == y); out_buffer.set_target(self.inv, inv); From 12d5239be6d74644132784a2dd955f546b7d6531 Mon Sep 17 00:00:00 2001 From: Nicholas Ward Date: Fri, 11 Feb 2022 13:24:46 -0800 Subject: [PATCH 24/56] fix --- plonky2/src/gadgets/curve.rs | 10 ++++++++++ plonky2/src/gadgets/curve_windowed_mul.rs | 4 ++-- 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/plonky2/src/gadgets/curve.rs b/plonky2/src/gadgets/curve.rs index 0a0650e9..66d2d1b1 100644 --- a/plonky2/src/gadgets/curve.rs +++ b/plonky2/src/gadgets/curve.rs @@ -95,6 +95,16 @@ impl, const D: usize> CircuitBuilder { AffinePointTarget { x: x3, y: y3 } } + pub fn curve_repeated_double(&mut self, p: &AffinePointTarget, n: usize) -> AffinePointTarget { + let mut result = p.clone(); + + for _ in 0..n { + result = self.curve_double(&result); + } + + result + } + // Add two points, which are assumed to be non-equal. pub fn curve_add( &mut self, diff --git a/plonky2/src/gadgets/curve_windowed_mul.rs b/plonky2/src/gadgets/curve_windowed_mul.rs index 57d8c558..002b6ec4 100644 --- a/plonky2/src/gadgets/curve_windowed_mul.rs +++ b/plonky2/src/gadgets/curve_windowed_mul.rs @@ -82,7 +82,7 @@ impl, const D: usize> CircuitBuilder { let windows = self.split_nonnative_to_4_bit_limbs(n); let m = C::ScalarField::BITS / WINDOW_SIZE; for i in (0..m).rev() { - result = self.curve_double(&result); + result = self.curve_repeated_double(&result, WINDOW_SIZE); let window = windows[i]; let to_add = self.random_access_curve_points(window, precomputation.clone()); @@ -147,7 +147,7 @@ mod tests { } #[test] - fn test_curve_mul_windowed() -> Result<()> { + fn test_curve_windowed_mul() -> Result<()> { const D: usize = 2; type C = PoseidonGoldilocksConfig; type F = >::F; From a89b306cf8c1468d7092c8826ef27cadd1117d53 Mon Sep 17 00:00:00 2001 From: Nicholas Ward Date: Fri, 11 Feb 2022 13:24:56 -0800 Subject: [PATCH 25/56] fmt --- plonky2/src/gadgets/curve.rs | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/plonky2/src/gadgets/curve.rs b/plonky2/src/gadgets/curve.rs index 66d2d1b1..2b27c120 100644 --- a/plonky2/src/gadgets/curve.rs +++ b/plonky2/src/gadgets/curve.rs @@ -95,7 +95,11 @@ impl, const D: usize> CircuitBuilder { AffinePointTarget { x: x3, y: y3 } } - pub fn curve_repeated_double(&mut self, p: &AffinePointTarget, n: usize) -> AffinePointTarget { + pub fn curve_repeated_double( + &mut self, + p: &AffinePointTarget, + n: usize, + ) -> AffinePointTarget { let mut result = p.clone(); for _ in 0..n { From f6f7e5519138f78a2f7ff37dcefade6eb175e445 Mon Sep 17 00:00:00 2001 From: Nicholas Ward Date: Wed, 16 Feb 2022 11:31:26 -0800 Subject: [PATCH 26/56] windowed mul fixes...... --- plonky2/src/gadgets/arithmetic.rs | 10 ++++++++++ plonky2/src/gadgets/curve_windowed_mul.rs | 21 ++++++++++++++++++++- plonky2/src/gadgets/nonnative.rs | 12 ++++++++++++ 3 files changed, 42 insertions(+), 1 deletion(-) diff --git a/plonky2/src/gadgets/arithmetic.rs b/plonky2/src/gadgets/arithmetic.rs index 70cd24ad..b7df3726 100644 --- a/plonky2/src/gadgets/arithmetic.rs +++ b/plonky2/src/gadgets/arithmetic.rs @@ -326,6 +326,16 @@ impl, const D: usize> CircuitBuilder { BoolTarget::new_unsafe(res) } + pub fn and(&mut self, b1: BoolTarget, b2: BoolTarget) -> BoolTarget { + BoolTarget::new_unsafe(self.mul(b1.target, b2.target)) + } + + pub fn _if(&mut self, b: BoolTarget, x: Target, y: Target) -> Target { + let not_b = self.not(b); + let maybe_x = self.mul(b.target, x); + self.mul_add(not_b.target, y, maybe_x) + } + pub fn is_equal(&mut self, x: Target, y: Target) -> BoolTarget { let zero = self.zero(); diff --git a/plonky2/src/gadgets/curve_windowed_mul.rs b/plonky2/src/gadgets/curve_windowed_mul.rs index 002b6ec4..b8b0f804 100644 --- a/plonky2/src/gadgets/curve_windowed_mul.rs +++ b/plonky2/src/gadgets/curve_windowed_mul.rs @@ -9,7 +9,7 @@ use crate::gadgets::biguint::BigUintTarget; use crate::gadgets::curve::AffinePointTarget; use crate::gadgets::nonnative::NonNativeTarget; use crate::hash::hash_types::RichField; -use crate::iop::target::Target; +use crate::iop::target::{BoolTarget, Target}; use crate::plonk::circuit_builder::CircuitBuilder; const WINDOW_SIZE: usize = 4; @@ -69,12 +69,25 @@ impl, const D: usize> CircuitBuilder { AffinePointTarget { x, y } } + pub fn if_affine_point( + &mut self, + b: BoolTarget, + p1: &AffinePointTarget, + p2: &AffinePointTarget, + ) -> AffinePointTarget { + let new_x = self.if_nonnative(b, &p1.x, &p2.x); + let new_y = self.if_nonnative(b, &p1.y, &p2.y); + AffinePointTarget { x: new_x, y: new_y } + } + pub fn curve_scalar_mul_windowed( &mut self, p: &AffinePointTarget, n: &NonNativeTarget, ) -> AffinePointTarget { let mut result = self.constant_affine_point(C::GENERATOR_AFFINE); + let mut to_subtract = self.constant_affine_point(C::GENERATOR_AFFINE); + let mut to_subtract_grows = self._true(); let precomputation = self.precompute_window(p); let zero = self.zero(); @@ -83,10 +96,15 @@ impl, const D: usize> CircuitBuilder { let m = C::ScalarField::BITS / WINDOW_SIZE; for i in (0..m).rev() { result = self.curve_repeated_double(&result, WINDOW_SIZE); + + let to_subtract_increased = self.curve_repeated_double(&to_subtract, WINDOW_SIZE); + to_subtract = self.if_affine_point(to_subtract_grows, &to_subtract_increased, &to_subtract); + let window = windows[i]; let to_add = self.random_access_curve_points(window, precomputation.clone()); let is_zero = self.is_equal(window, zero); + to_subtract_grows = self.and(to_subtract_grows, is_zero); let should_add = self.not(is_zero); result = self.curve_conditional_add(&result, &to_add, should_add); } @@ -173,6 +191,7 @@ mod tests { builder.connect_affine_point(&neg_five_g_expected, &neg_five_g_actual); + println!("NUM GATES: {}", builder.num_gates()); let data = builder.build::(); let proof = data.prove(pw).unwrap(); diff --git a/plonky2/src/gadgets/nonnative.rs b/plonky2/src/gadgets/nonnative.rs index 3f8d29e8..44969ac9 100644 --- a/plonky2/src/gadgets/nonnative.rs +++ b/plonky2/src/gadgets/nonnative.rs @@ -106,6 +106,18 @@ impl, const D: usize> CircuitBuilder { } } + pub fn if_nonnative( + &mut self, + b: BoolTarget, + x: &NonNativeTarget, + y: &NonNativeTarget, + ) -> NonNativeTarget { + let not_b = self.not(b); + let maybe_x = self.mul_nonnative_by_bool(x, b); + let maybe_y = self.mul_nonnative_by_bool(y, not_b); + self.add_nonnative(&maybe_x, &maybe_y) + } + pub fn add_many_nonnative( &mut self, to_add: &[NonNativeTarget], From f77192ef6629e7d052f355710d38a6d0d051e73c Mon Sep 17 00:00:00 2001 From: Nicholas Ward Date: Fri, 11 Feb 2022 15:16:11 -0800 Subject: [PATCH 27/56] fmt --- plonky2/src/gadgets/curve_windowed_mul.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/plonky2/src/gadgets/curve_windowed_mul.rs b/plonky2/src/gadgets/curve_windowed_mul.rs index b8b0f804..e5e8f3a7 100644 --- a/plonky2/src/gadgets/curve_windowed_mul.rs +++ b/plonky2/src/gadgets/curve_windowed_mul.rs @@ -98,7 +98,8 @@ impl, const D: usize> CircuitBuilder { result = self.curve_repeated_double(&result, WINDOW_SIZE); let to_subtract_increased = self.curve_repeated_double(&to_subtract, WINDOW_SIZE); - to_subtract = self.if_affine_point(to_subtract_grows, &to_subtract_increased, &to_subtract); + to_subtract = + self.if_affine_point(to_subtract_grows, &to_subtract_increased, &to_subtract); let window = windows[i]; From e88564ce5e458ec24c1e888be24b68983e50b626 Mon Sep 17 00:00:00 2001 From: Nicholas Ward Date: Mon, 14 Feb 2022 11:03:44 -0800 Subject: [PATCH 28/56] correct point subtraction --- plonky2/src/gadgets/curve_windowed_mul.rs | 32 ++++++++++++++++------- 1 file changed, 22 insertions(+), 10 deletions(-) diff --git a/plonky2/src/gadgets/curve_windowed_mul.rs b/plonky2/src/gadgets/curve_windowed_mul.rs index e5e8f3a7..0a461cd7 100644 --- a/plonky2/src/gadgets/curve_windowed_mul.rs +++ b/plonky2/src/gadgets/curve_windowed_mul.rs @@ -1,16 +1,19 @@ use std::marker::PhantomData; +use num::BigUint; use plonky2_field::extension_field::Extendable; use plonky2_field::field_types::Field; -use crate::curve::curve_types::Curve; +use crate::curve::curve_types::{Curve, CurveScalar}; use crate::gadgets::arithmetic_u32::U32Target; use crate::gadgets::biguint::BigUintTarget; use crate::gadgets::curve::AffinePointTarget; use crate::gadgets::nonnative::NonNativeTarget; use crate::hash::hash_types::RichField; +use crate::hash::keccak::KeccakHash; use crate::iop::target::{BoolTarget, Target}; use crate::plonk::circuit_builder::CircuitBuilder; +use crate::plonk::config::{GenericHashOut, Hasher}; const WINDOW_SIZE: usize = 4; @@ -85,9 +88,20 @@ impl, const D: usize> CircuitBuilder { p: &AffinePointTarget, n: &NonNativeTarget, ) -> AffinePointTarget { - let mut result = self.constant_affine_point(C::GENERATOR_AFFINE); - let mut to_subtract = self.constant_affine_point(C::GENERATOR_AFFINE); - let mut to_subtract_grows = self._true(); + let hash_0 = KeccakHash::<25>::hash(&[F::ZERO], false); + let hash_0_scalar = C::ScalarField::from_biguint(BigUint::from_bytes_le( + &GenericHashOut::::to_bytes(&hash_0), + )); + let starting_point = CurveScalar(hash_0_scalar) * C::GENERATOR_PROJECTIVE; + let starting_point_multiplied = { + let mut cur = starting_point; + for _ in 0..C::ScalarField::BITS { + cur = cur.double(); + } + cur + }; + + let mut result = self.constant_affine_point(starting_point.to_affine()); let precomputation = self.precompute_window(p); let zero = self.zero(); @@ -96,20 +110,18 @@ impl, const D: usize> CircuitBuilder { let m = C::ScalarField::BITS / WINDOW_SIZE; for i in (0..m).rev() { result = self.curve_repeated_double(&result, WINDOW_SIZE); - - let to_subtract_increased = self.curve_repeated_double(&to_subtract, WINDOW_SIZE); - to_subtract = - self.if_affine_point(to_subtract_grows, &to_subtract_increased, &to_subtract); - let window = windows[i]; let to_add = self.random_access_curve_points(window, precomputation.clone()); let is_zero = self.is_equal(window, zero); - to_subtract_grows = self.and(to_subtract_grows, is_zero); let should_add = self.not(is_zero); result = self.curve_conditional_add(&result, &to_add, should_add); } + let to_subtract = self.constant_affine_point(starting_point_multiplied.to_affine()); + let to_add = self.curve_neg(&to_subtract); + result = self.curve_add(&result, &to_add); + result } } From 0140f7a3cfbf8eee4f7dd17cc6577e0aa24866c5 Mon Sep 17 00:00:00 2001 From: Nicholas Ward Date: Mon, 14 Feb 2022 14:05:21 -0800 Subject: [PATCH 29/56] fixes --- plonky2/src/gadgets/curve_windowed_mul.rs | 25 +++++++++++++---------- plonky2/src/plonk/circuit_builder.rs | 1 + 2 files changed, 15 insertions(+), 11 deletions(-) diff --git a/plonky2/src/gadgets/curve_windowed_mul.rs b/plonky2/src/gadgets/curve_windowed_mul.rs index 0a461cd7..d6212108 100644 --- a/plonky2/src/gadgets/curve_windowed_mul.rs +++ b/plonky2/src/gadgets/curve_windowed_mul.rs @@ -23,16 +23,20 @@ impl, const D: usize> CircuitBuilder { &mut self, p: &AffinePointTarget, ) -> Vec> { - let mut multiples = vec![self.constant_affine_point(C::GENERATOR_AFFINE)]; - let mut cur = p.clone(); - for _pow in 0..WINDOW_SIZE { - for existing in multiples.clone() { - multiples.push(self.curve_add(&cur, &existing)); - } - cur = self.curve_double(&cur); - } + let neg = { + let mut g = C::GENERATOR_AFFINE; + g.y = -g.y; + self.constant_affine_point(g) + }; - multiples + let mut multiples = vec![self.constant_affine_point(C::GENERATOR_AFFINE)]; + for i in 1..1 << WINDOW_SIZE { + multiples.push(self.curve_add(p, &multiples[i - 1])); + } + for i in 1..1 << WINDOW_SIZE { + multiples[i] = self.curve_add(&neg, &multiples[i]); + } + multiples } pub fn random_access_curve_points( @@ -107,8 +111,7 @@ impl, const D: usize> CircuitBuilder { let zero = self.zero(); let windows = self.split_nonnative_to_4_bit_limbs(n); - let m = C::ScalarField::BITS / WINDOW_SIZE; - for i in (0..m).rev() { + for i in (0..windows.len()).rev() { result = self.curve_repeated_double(&result, WINDOW_SIZE); let window = windows[i]; diff --git a/plonky2/src/plonk/circuit_builder.rs b/plonky2/src/plonk/circuit_builder.rs index ff975659..b3842539 100644 --- a/plonky2/src/plonk/circuit_builder.rs +++ b/plonky2/src/plonk/circuit_builder.rs @@ -150,6 +150,7 @@ impl, const D: usize> CircuitBuilder { /// generate the final witness (a grid of wire values), these virtual targets will go away. pub fn add_virtual_target(&mut self) -> Target { let index = self.virtual_target_index; + self.virtual_target_index += 1; Target::VirtualTarget { index } } From 1e3743f46c4f953b0fad04065970703129b7c92d Mon Sep 17 00:00:00 2001 From: Nicholas Ward Date: Mon, 14 Feb 2022 14:05:26 -0800 Subject: [PATCH 30/56] fmt --- plonky2/src/gadgets/curve_windowed_mul.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/plonky2/src/gadgets/curve_windowed_mul.rs b/plonky2/src/gadgets/curve_windowed_mul.rs index d6212108..08977dd6 100644 --- a/plonky2/src/gadgets/curve_windowed_mul.rs +++ b/plonky2/src/gadgets/curve_windowed_mul.rs @@ -36,7 +36,7 @@ impl, const D: usize> CircuitBuilder { for i in 1..1 << WINDOW_SIZE { multiples[i] = self.curve_add(&neg, &multiples[i]); } - multiples + multiples } pub fn random_access_curve_points( From 8ad193db174bcd6da1511b71bf5d8caf06cf0e04 Mon Sep 17 00:00:00 2001 From: Nicholas Ward Date: Mon, 14 Feb 2022 14:12:52 -0800 Subject: [PATCH 31/56] use windowed mul in GLV --- plonky2/src/gadgets/curve_windowed_mul.rs | 1 - plonky2/src/gadgets/glv.rs | 5 ++--- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/plonky2/src/gadgets/curve_windowed_mul.rs b/plonky2/src/gadgets/curve_windowed_mul.rs index 08977dd6..2f0516dd 100644 --- a/plonky2/src/gadgets/curve_windowed_mul.rs +++ b/plonky2/src/gadgets/curve_windowed_mul.rs @@ -207,7 +207,6 @@ mod tests { builder.connect_affine_point(&neg_five_g_expected, &neg_five_g_actual); - println!("NUM GATES: {}", builder.num_gates()); let data = builder.build::(); let proof = data.prove(pw).unwrap(); diff --git a/plonky2/src/gadgets/glv.rs b/plonky2/src/gadgets/glv.rs index 0a4afb3c..5ac3fb93 100644 --- a/plonky2/src/gadgets/glv.rs +++ b/plonky2/src/gadgets/glv.rs @@ -53,9 +53,8 @@ impl, const D: usize> CircuitBuilder { y: p.y.clone(), }; - // TODO: replace with MSM - let part1 = self.curve_scalar_mul(p, &k1); - let part2 = self.curve_scalar_mul(&sp, &k2); + let part1 = self.curve_scalar_mul_windowed(p, &k1); + let part2 = self.curve_scalar_mul_windowed(&sp, &k2); self.curve_add(&part1, &part2) } From 25555c15e0f99be470d9bf9a198e139ac37dc30f Mon Sep 17 00:00:00 2001 From: Nicholas Ward Date: Wed, 16 Feb 2022 11:31:43 -0800 Subject: [PATCH 32/56] fixed native GLV; fixed precompute window; other fixes --- plonky2/src/curve/glv.rs | 49 ++++++++++++++++++----- plonky2/src/gadgets/curve.rs | 18 +++++++++ plonky2/src/gadgets/curve_windowed_mul.rs | 10 ++--- plonky2/src/gadgets/glv.rs | 28 +++++++++---- plonky2/src/gadgets/nonnative.rs | 12 ++++++ 5 files changed, 93 insertions(+), 24 deletions(-) diff --git a/plonky2/src/curve/glv.rs b/plonky2/src/curve/glv.rs index ded0d0b3..a2786e31 100644 --- a/plonky2/src/curve/glv.rs +++ b/plonky2/src/curve/glv.rs @@ -1,4 +1,5 @@ use num::rational::Ratio; +use num::BigUint; use plonky2_field::field_types::Field; use plonky2_field::secp256k1_base::Secp256K1Base; use plonky2_field::secp256k1_scalar::Secp256K1Scalar; @@ -30,26 +31,46 @@ const A2: Secp256K1Scalar = Secp256K1Scalar([6323353552219852760, 14980988506747 const B2: Secp256K1Scalar = Secp256K1Scalar([16747920425669159701, 3496713202691238861, 0, 0]); -pub fn decompose_secp256k1_scalar(k: Secp256K1Scalar) -> (Secp256K1Scalar, Secp256K1Scalar) { +pub fn decompose_secp256k1_scalar( + k: Secp256K1Scalar, +) -> (Secp256K1Scalar, Secp256K1Scalar, bool, bool) { let p = Secp256K1Scalar::order(); let c1_biguint = Ratio::new(B2.to_biguint() * k.to_biguint(), p.clone()) .round() .to_integer(); let c1 = Secp256K1Scalar::from_biguint(c1_biguint); - let c2_biguint = Ratio::new(MINUS_B1.to_biguint() * k.to_biguint(), p) + let c2_biguint = Ratio::new(MINUS_B1.to_biguint() * k.to_biguint(), p.clone()) .round() .to_integer(); let c2 = Secp256K1Scalar::from_biguint(c2_biguint); - let k1 = k - c1 * A1 - c2 * A2; - let k2 = c1 * MINUS_B1 - c2 * B2; - debug_assert!(k1 + S * k2 == k); - (k1, k2) + let k1_raw = k - c1 * A1 - c2 * A2; + let k2_raw = c1 * MINUS_B1 - c2 * B2; + debug_assert!(k1_raw + S * k2_raw == k); + + let two = BigUint::from_slice(&[2]); + let k1_neg = k1_raw.to_biguint() > p.clone() / two.clone(); + let k1 = if k1_neg { + Secp256K1Scalar::from_biguint(p.clone() - k1_raw.to_biguint()) + } else { + k1_raw + }; + let k2_neg = k2_raw.to_biguint() > p.clone() / two.clone(); + let k2 = if k2_neg { + Secp256K1Scalar::from_biguint(p.clone() - k2_raw.to_biguint()) + } else { + k2_raw + }; + + (k1, k2, k1_neg, k2_neg) } pub fn glv_mul(p: ProjectivePoint, k: Secp256K1Scalar) -> ProjectivePoint { - let (k1, k2) = decompose_secp256k1_scalar(k); - assert!(k1 + S * k2 == k); + let (k1, k2, k1_neg, k2_neg) = decompose_secp256k1_scalar(k); + let one = Secp256K1Scalar::ONE; + /*let m1 = if k1_neg { -one } else { one }; + let m2 = if k2_neg { -one } else { one }; + assert!(k1 * m1 + S * k2 * m2 == k);*/ let p_affine = p.to_affine(); let sp = AffinePoint:: { @@ -58,7 +79,10 @@ pub fn glv_mul(p: ProjectivePoint, k: Secp256K1Scalar) -> ProjectiveP zero: p_affine.zero, }; - msm_parallel(&[k1, k2], &[p, sp.to_projective()], 5) + let first = if k1_neg { p.neg() } else { p }; + let second = if k2_neg { sp.to_projective().neg() } else { sp.to_projective() }; + + msm_parallel(&[k1, k2], &[first, second], 5) } #[cfg(test)] @@ -74,9 +98,12 @@ mod tests { #[test] fn test_glv_decompose() -> Result<()> { let k = Secp256K1Scalar::rand(); - let (k1, k2) = decompose_secp256k1_scalar(k); + let (k1, k2, k1_neg, k2_neg) = decompose_secp256k1_scalar(k); + let one = Secp256K1Scalar::ONE; + let m1 = if k1_neg { -one } else { one }; + let m2 = if k2_neg { -one } else { one }; - assert!(k1 + S * k2 == k); + assert!(k1 * m1 + S * k2 * m2 == k); Ok(()) } diff --git a/plonky2/src/gadgets/curve.rs b/plonky2/src/gadgets/curve.rs index 2b27c120..aa8c4112 100644 --- a/plonky2/src/gadgets/curve.rs +++ b/plonky2/src/gadgets/curve.rs @@ -71,6 +71,24 @@ impl, const D: usize> CircuitBuilder { } } + pub fn curve_conditional_neg( + &mut self, + p: &AffinePointTarget, + b: BoolTarget, + ) -> AffinePointTarget { + let not_b = self.not(b); + let neg = self.curve_neg(p); + let x_if_true = self.mul_nonnative_by_bool(&neg.x, b); + let y_if_true = self.mul_nonnative_by_bool(&neg.y, b); + let x_if_false = self.mul_nonnative_by_bool(&p.x, not_b); + let y_if_false = self.mul_nonnative_by_bool(&p.y, not_b); + + let x = self.add_nonnative(&x_if_true, &x_if_false); + let y = self.add_nonnative(&y_if_true, &y_if_false); + + AffinePointTarget { x, y } + } + pub fn curve_double(&mut self, p: &AffinePointTarget) -> AffinePointTarget { let AffinePointTarget { x, y } = p; let double_y = self.add_nonnative(y, y); diff --git a/plonky2/src/gadgets/curve_windowed_mul.rs b/plonky2/src/gadgets/curve_windowed_mul.rs index 2f0516dd..f4cebe0e 100644 --- a/plonky2/src/gadgets/curve_windowed_mul.rs +++ b/plonky2/src/gadgets/curve_windowed_mul.rs @@ -18,18 +18,18 @@ use crate::plonk::config::{GenericHashOut, Hasher}; const WINDOW_SIZE: usize = 4; impl, const D: usize> CircuitBuilder { - // TODO: fix if p is the generator pub fn precompute_window( &mut self, p: &AffinePointTarget, ) -> Vec> { + let g = (CurveScalar(C::ScalarField::rand()) * C::GENERATOR_PROJECTIVE).to_affine(); let neg = { - let mut g = C::GENERATOR_AFFINE; - g.y = -g.y; - self.constant_affine_point(g) + let mut neg = g; + neg.y = -neg.y; + self.constant_affine_point(neg) }; - let mut multiples = vec![self.constant_affine_point(C::GENERATOR_AFFINE)]; + let mut multiples = vec![self.constant_affine_point(g)]; for i in 1..1 << WINDOW_SIZE { multiples.push(self.curve_add(p, &multiples[i - 1])); } diff --git a/plonky2/src/gadgets/glv.rs b/plonky2/src/gadgets/glv.rs index 5ac3fb93..5614de55 100644 --- a/plonky2/src/gadgets/glv.rs +++ b/plonky2/src/gadgets/glv.rs @@ -10,7 +10,7 @@ use crate::gadgets::curve::AffinePointTarget; use crate::gadgets::nonnative::NonNativeTarget; use crate::hash::hash_types::RichField; use crate::iop::generator::{GeneratedValues, SimpleGenerator}; -use crate::iop::target::Target; +use crate::iop::target::{BoolTarget, Target}; use crate::iop::witness::{PartitionWitness, Witness}; use crate::plonk::circuit_builder::CircuitBuilder; @@ -25,18 +25,24 @@ impl, const D: usize> CircuitBuilder { ) -> ( NonNativeTarget, NonNativeTarget, + BoolTarget, + BoolTarget, ) { - let k1 = self.add_virtual_nonnative_target::(); - let k2 = self.add_virtual_nonnative_target::(); + let k1 = self.add_virtual_nonnative_target_sized::(4); + let k2 = self.add_virtual_nonnative_target_sized::(4); + let k1_neg = self.add_virtual_bool_target(); + let k2_neg = self.add_virtual_bool_target(); self.add_simple_generator(GLVDecompositionGenerator:: { k: k.clone(), k1: k1.clone(), k2: k2.clone(), + k1_neg: k1_neg.clone(), + k2_neg: k2_neg.clone(), _phantom: PhantomData, }); - (k1, k2) + (k1, k2, k1_neg, k2_neg) } pub fn glv_mul( @@ -44,7 +50,7 @@ impl, const D: usize> CircuitBuilder { p: &AffinePointTarget, k: &NonNativeTarget, ) -> AffinePointTarget { - let (k1, k2) = self.decompose_secp256k1_scalar(k); + let (k1, k2, k1_neg, k2_neg) = self.decompose_secp256k1_scalar(k); let beta = self.secp256k1_glv_beta(); let beta_px = self.mul_nonnative(&beta, &p.x); @@ -54,9 +60,11 @@ impl, const D: usize> CircuitBuilder { }; let part1 = self.curve_scalar_mul_windowed(p, &k1); + let part1_neg = self.curve_conditional_neg(&part1, k1_neg); let part2 = self.curve_scalar_mul_windowed(&sp, &k2); + let part2_neg = self.curve_conditional_neg(&part2, k2_neg); - self.curve_add(&part1, &part2) + self.curve_add(&part1_neg, &part2_neg) } } @@ -65,6 +73,8 @@ struct GLVDecompositionGenerator, const D: usize> { k: NonNativeTarget, k1: NonNativeTarget, k2: NonNativeTarget, + k1_neg: BoolTarget, + k2_neg: BoolTarget, _phantom: PhantomData, } @@ -77,10 +87,12 @@ impl, const D: usize> SimpleGenerator fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { let k = witness.get_nonnative_target(self.k.clone()); - let (k1, k2) = decompose_secp256k1_scalar(k); + let (k1, k2, k1_neg, k2_neg) = decompose_secp256k1_scalar(k); out_buffer.set_nonnative_target(self.k1.clone(), k1); out_buffer.set_nonnative_target(self.k2.clone(), k2); + out_buffer.set_bool_target(self.k1_neg.clone(), k1_neg); + out_buffer.set_bool_target(self.k2_neg.clone(), k2_neg); } } @@ -100,7 +112,7 @@ mod tests { use crate::plonk::verifier::verify; #[test] - fn test_glv() -> Result<()> { + fn test_glv_gadget() -> Result<()> { const D: usize = 2; type C = PoseidonGoldilocksConfig; type F = >::F; diff --git a/plonky2/src/gadgets/nonnative.rs b/plonky2/src/gadgets/nonnative.rs index 44969ac9..046931d2 100644 --- a/plonky2/src/gadgets/nonnative.rs +++ b/plonky2/src/gadgets/nonnative.rs @@ -63,6 +63,18 @@ impl, const D: usize> CircuitBuilder { } } + pub fn add_virtual_nonnative_target_sized( + &mut self, + num_limbs: usize, + ) -> NonNativeTarget { + let value = self.add_virtual_biguint_target(num_limbs); + + NonNativeTarget { + value, + _phantom: PhantomData, + } + } + pub fn add_nonnative( &mut self, a: &NonNativeTarget, From 74cf5da8e0ce20f38a56035fd9cdb9f506dea4de Mon Sep 17 00:00:00 2001 From: Nicholas Ward Date: Wed, 16 Feb 2022 11:07:56 -0800 Subject: [PATCH 33/56] clippy --- plonky2/src/curve/glv.rs | 14 +++++++++----- plonky2/src/gadgets/glv.rs | 8 ++++---- 2 files changed, 13 insertions(+), 9 deletions(-) diff --git a/plonky2/src/curve/glv.rs b/plonky2/src/curve/glv.rs index a2786e31..591f61ea 100644 --- a/plonky2/src/curve/glv.rs +++ b/plonky2/src/curve/glv.rs @@ -55,9 +55,9 @@ pub fn decompose_secp256k1_scalar( } else { k1_raw }; - let k2_neg = k2_raw.to_biguint() > p.clone() / two.clone(); + let k2_neg = k2_raw.to_biguint() > p.clone() / two; let k2 = if k2_neg { - Secp256K1Scalar::from_biguint(p.clone() - k2_raw.to_biguint()) + Secp256K1Scalar::from_biguint(p - k2_raw.to_biguint()) } else { k2_raw }; @@ -67,8 +67,8 @@ pub fn decompose_secp256k1_scalar( pub fn glv_mul(p: ProjectivePoint, k: Secp256K1Scalar) -> ProjectivePoint { let (k1, k2, k1_neg, k2_neg) = decompose_secp256k1_scalar(k); - let one = Secp256K1Scalar::ONE; - /*let m1 = if k1_neg { -one } else { one }; + /*let one = Secp256K1Scalar::ONE; + let m1 = if k1_neg { -one } else { one }; let m2 = if k2_neg { -one } else { one }; assert!(k1 * m1 + S * k2 * m2 == k);*/ @@ -80,7 +80,11 @@ pub fn glv_mul(p: ProjectivePoint, k: Secp256K1Scalar) -> ProjectiveP }; let first = if k1_neg { p.neg() } else { p }; - let second = if k2_neg { sp.to_projective().neg() } else { sp.to_projective() }; + let second = if k2_neg { + sp.to_projective().neg() + } else { + sp.to_projective() + }; msm_parallel(&[k1, k2], &[first, second], 5) } diff --git a/plonky2/src/gadgets/glv.rs b/plonky2/src/gadgets/glv.rs index 5614de55..89013dd1 100644 --- a/plonky2/src/gadgets/glv.rs +++ b/plonky2/src/gadgets/glv.rs @@ -37,8 +37,8 @@ impl, const D: usize> CircuitBuilder { k: k.clone(), k1: k1.clone(), k2: k2.clone(), - k1_neg: k1_neg.clone(), - k2_neg: k2_neg.clone(), + k1_neg, + k2_neg, _phantom: PhantomData, }); @@ -91,8 +91,8 @@ impl, const D: usize> SimpleGenerator out_buffer.set_nonnative_target(self.k1.clone(), k1); out_buffer.set_nonnative_target(self.k2.clone(), k2); - out_buffer.set_bool_target(self.k1_neg.clone(), k1_neg); - out_buffer.set_bool_target(self.k2_neg.clone(), k2_neg); + out_buffer.set_bool_target(self.k1_neg, k1_neg); + out_buffer.set_bool_target(self.k2_neg, k2_neg); } } From 20fc5e2da559261c31d6b8f2c3097184286e6186 Mon Sep 17 00:00:00 2001 From: Nicholas Ward Date: Wed, 16 Feb 2022 11:36:51 -0800 Subject: [PATCH 34/56] merge fixes --- plonky2/src/curve/glv.rs | 28 +- plonky2/src/gadgets/curve_windowed_mul.rs | 2 +- plonky2/src/gadgets/nonnative.rs | 9 +- plonky2/src/plonk/circuit_builder.rs | 440 ---------------------- 4 files changed, 24 insertions(+), 455 deletions(-) diff --git a/plonky2/src/curve/glv.rs b/plonky2/src/curve/glv.rs index 591f61ea..11172be0 100644 --- a/plonky2/src/curve/glv.rs +++ b/plonky2/src/curve/glv.rs @@ -1,6 +1,6 @@ use num::rational::Ratio; use num::BigUint; -use plonky2_field::field_types::Field; +use plonky2_field::field_types::{Field, PrimeField}; use plonky2_field::secp256k1_base::Secp256K1Base; use plonky2_field::secp256k1_scalar::Secp256K1Scalar; @@ -35,13 +35,19 @@ pub fn decompose_secp256k1_scalar( k: Secp256K1Scalar, ) -> (Secp256K1Scalar, Secp256K1Scalar, bool, bool) { let p = Secp256K1Scalar::order(); - let c1_biguint = Ratio::new(B2.to_biguint() * k.to_biguint(), p.clone()) - .round() - .to_integer(); + let c1_biguint = Ratio::new( + B2.to_canonical_biguint() * k.to_canonical_biguint(), + p.clone(), + ) + .round() + .to_integer(); let c1 = Secp256K1Scalar::from_biguint(c1_biguint); - let c2_biguint = Ratio::new(MINUS_B1.to_biguint() * k.to_biguint(), p.clone()) - .round() - .to_integer(); + let c2_biguint = Ratio::new( + MINUS_B1.to_canonical_biguint() * k.to_canonical_biguint(), + p.clone(), + ) + .round() + .to_integer(); let c2 = Secp256K1Scalar::from_biguint(c2_biguint); let k1_raw = k - c1 * A1 - c2 * A2; @@ -49,15 +55,15 @@ pub fn decompose_secp256k1_scalar( debug_assert!(k1_raw + S * k2_raw == k); let two = BigUint::from_slice(&[2]); - let k1_neg = k1_raw.to_biguint() > p.clone() / two.clone(); + let k1_neg = k1_raw.to_canonical_biguint() > p.clone() / two.clone(); let k1 = if k1_neg { - Secp256K1Scalar::from_biguint(p.clone() - k1_raw.to_biguint()) + Secp256K1Scalar::from_biguint(p.clone() - k1_raw.to_canonical_biguint()) } else { k1_raw }; - let k2_neg = k2_raw.to_biguint() > p.clone() / two; + let k2_neg = k2_raw.to_canonical_biguint() > p.clone() / two; let k2 = if k2_neg { - Secp256K1Scalar::from_biguint(p - k2_raw.to_biguint()) + Secp256K1Scalar::from_biguint(p - k2_raw.to_canonical_biguint()) } else { k2_raw }; diff --git a/plonky2/src/gadgets/curve_windowed_mul.rs b/plonky2/src/gadgets/curve_windowed_mul.rs index f4cebe0e..879e1ade 100644 --- a/plonky2/src/gadgets/curve_windowed_mul.rs +++ b/plonky2/src/gadgets/curve_windowed_mul.rs @@ -92,7 +92,7 @@ impl, const D: usize> CircuitBuilder { p: &AffinePointTarget, n: &NonNativeTarget, ) -> AffinePointTarget { - let hash_0 = KeccakHash::<25>::hash(&[F::ZERO], false); + let hash_0 = KeccakHash::<25>::hash_no_pad(&[F::ZERO]); let hash_0_scalar = C::ScalarField::from_biguint(BigUint::from_bytes_le( &GenericHashOut::::to_bytes(&hash_0), )); diff --git a/plonky2/src/gadgets/nonnative.rs b/plonky2/src/gadgets/nonnative.rs index 046931d2..910915d0 100644 --- a/plonky2/src/gadgets/nonnative.rs +++ b/plonky2/src/gadgets/nonnative.rs @@ -31,7 +31,10 @@ impl, const D: usize> CircuitBuilder { } } - pub fn nonnative_to_biguint(&mut self, x: &NonNativeTarget) -> BigUintTarget { + pub fn nonnative_to_canonical_biguint( + &mut self, + x: &NonNativeTarget, + ) -> BigUintTarget { x.value.clone() } @@ -118,7 +121,7 @@ impl, const D: usize> CircuitBuilder { } } - pub fn if_nonnative( + pub fn if_nonnative( &mut self, b: BoolTarget, x: &NonNativeTarget, @@ -300,7 +303,7 @@ impl, const D: usize> CircuitBuilder { } pub fn reduce_nonnative(&mut self, x: &NonNativeTarget) -> NonNativeTarget { - let x_biguint = self.nonnative_to_biguint(x); + let x_biguint = self.nonnative_to_canonical_biguint(x); self.reduce(&x_biguint) } diff --git a/plonky2/src/plonk/circuit_builder.rs b/plonky2/src/plonk/circuit_builder.rs index b3842539..63f45fec 100644 --- a/plonky2/src/plonk/circuit_builder.rs +++ b/plonky2/src/plonk/circuit_builder.rs @@ -847,443 +847,3 @@ impl, const D: usize> CircuitBuilder { } } } -<<<<<<< HEAD -======= - -/// Various gate types can contain multiple copies in a single Gate. This helper struct lets a -/// CircuitBuilder track such gates that are currently being "filled up." -pub struct BatchedGates, const D: usize> { - /// A map `(c0, c1) -> (g, i)` from constants `(c0,c1)` to an available arithmetic gate using - /// these constants with gate index `g` and already using `i` arithmetic operations. - pub(crate) free_arithmetic: HashMap<(F, F), (usize, usize)>, - pub(crate) free_base_arithmetic: HashMap<(F, F), (usize, usize)>, - - pub(crate) free_mul: HashMap, - - /// A map `b -> (g, i)` from `b` bits to an available random access gate of that size with gate - /// index `g` and already using `i` random accesses. - pub(crate) free_random_access: HashMap, - - /// `current_switch_gates[chunk_size - 1]` contains None if we have no switch gates with the value - /// chunk_size, and contains `(g, i, c)`, if the gate `g`, at index `i`, already contains `c` copies - /// of switches - pub(crate) current_switch_gates: Vec, usize, usize)>>, - - /// A map `n -> (g, i)` from `n` number of addends to an available `U32AddManyGate` of that size with gate - /// index `g` and already using `i` random accesses. - pub(crate) free_u32_add_many: HashMap, - - /// The `U32ArithmeticGate` currently being filled (so new u32 arithmetic operations will be added to this gate before creating a new one) - pub(crate) current_u32_arithmetic_gate: Option<(usize, usize)>, - /// The `U32SubtractionGate` currently being filled (so new u32 subtraction operations will be added to this gate before creating a new one) - pub(crate) current_u32_subtraction_gate: Option<(usize, usize)>, - - /// An available `ConstantGate` instance, if any. - pub(crate) free_constant: Option<(usize, usize)>, -} - -impl, const D: usize> BatchedGates { - pub fn new() -> Self { - Self { - free_arithmetic: HashMap::new(), - free_base_arithmetic: HashMap::new(), - free_mul: HashMap::new(), - free_random_access: HashMap::new(), - current_switch_gates: Vec::new(), - free_u32_add_many: HashMap::new(), - current_u32_arithmetic_gate: None, - current_u32_subtraction_gate: None, - free_constant: None, - } - } -} - -impl, const D: usize> CircuitBuilder { - /// Finds the last available arithmetic gate with the given constants or add one if there aren't any. - /// Returns `(g,i)` such that there is an arithmetic gate with the given constants at index - /// `g` and the gate's `i`-th operation is available. - pub(crate) fn find_base_arithmetic_gate(&mut self, const_0: F, const_1: F) -> (usize, usize) { - let (gate, i) = self - .batched_gates - .free_base_arithmetic - .get(&(const_0, const_1)) - .copied() - .unwrap_or_else(|| { - let gate = self.add_gate( - ArithmeticGate::new_from_config(&self.config), - vec![const_0, const_1], - ); - (gate, 0) - }); - - // Update `free_arithmetic` with new values. - if i < ArithmeticGate::num_ops(&self.config) - 1 { - self.batched_gates - .free_base_arithmetic - .insert((const_0, const_1), (gate, i + 1)); - } else { - self.batched_gates - .free_base_arithmetic - .remove(&(const_0, const_1)); - } - - (gate, i) - } - - /// Finds the last available arithmetic gate with the given constants or add one if there aren't any. - /// Returns `(g,i)` such that there is an arithmetic gate with the given constants at index - /// `g` and the gate's `i`-th operation is available. - pub(crate) fn find_arithmetic_gate(&mut self, const_0: F, const_1: F) -> (usize, usize) { - let (gate, i) = self - .batched_gates - .free_arithmetic - .get(&(const_0, const_1)) - .copied() - .unwrap_or_else(|| { - let gate = self.add_gate( - ArithmeticExtensionGate::new_from_config(&self.config), - vec![const_0, const_1], - ); - (gate, 0) - }); - - // Update `free_arithmetic` with new values. - if i < ArithmeticExtensionGate::::num_ops(&self.config) - 1 { - self.batched_gates - .free_arithmetic - .insert((const_0, const_1), (gate, i + 1)); - } else { - self.batched_gates - .free_arithmetic - .remove(&(const_0, const_1)); - } - - (gate, i) - } - - /// Finds the last available arithmetic gate with the given constants or add one if there aren't any. - /// Returns `(g,i)` such that there is an arithmetic gate with the given constants at index - /// `g` and the gate's `i`-th operation is available. - pub(crate) fn find_mul_gate(&mut self, const_0: F) -> (usize, usize) { - let (gate, i) = self - .batched_gates - .free_mul - .get(&const_0) - .copied() - .unwrap_or_else(|| { - let gate = self.add_gate( - MulExtensionGate::new_from_config(&self.config), - vec![const_0], - ); - (gate, 0) - }); - - // Update `free_arithmetic` with new values. - if i < MulExtensionGate::::num_ops(&self.config) - 1 { - self.batched_gates.free_mul.insert(const_0, (gate, i + 1)); - } else { - self.batched_gates.free_mul.remove(&const_0); - } - - (gate, i) - } - - /// Finds the last available random access gate with the given `bits` or adds one if there aren't any. - /// Returns `(g,i)` such that there is a random access gate for the given `bits` at index - /// `g` and the gate's `i`-th random access is available. - pub(crate) fn find_random_access_gate(&mut self, bits: usize) -> (usize, usize) { - let (gate, i) = self - .batched_gates - .free_random_access - .get(&bits) - .copied() - .unwrap_or_else(|| { - let gate = self.add_gate( - RandomAccessGate::new_from_config(&self.config, bits), - vec![], - ); - (gate, 0) - }); - - // Update `free_random_access` with new values. - if i + 1 < RandomAccessGate::::new_from_config(&self.config, bits).num_copies { - self.batched_gates - .free_random_access - .insert(bits, (gate, i + 1)); - } else { - self.batched_gates.free_random_access.remove(&bits); - } - - (gate, i) - } - - pub fn find_switch_gate(&mut self, chunk_size: usize) -> (SwitchGate, usize, usize) { - if self.batched_gates.current_switch_gates.len() < chunk_size { - self.batched_gates.current_switch_gates.extend(vec![ - None; - chunk_size - - self - .batched_gates - .current_switch_gates - .len() - ]); - } - - let (gate, gate_index, next_copy) = - match self.batched_gates.current_switch_gates[chunk_size - 1].clone() { - None => { - let gate = SwitchGate::::new_from_config(&self.config, chunk_size); - let gate_index = self.add_gate(gate.clone(), vec![]); - (gate, gate_index, 0) - } - Some((gate, idx, next_copy)) => (gate, idx, next_copy), - }; - - let num_copies = gate.num_copies; - - if next_copy == num_copies - 1 { - self.batched_gates.current_switch_gates[chunk_size - 1] = None; - } else { - self.batched_gates.current_switch_gates[chunk_size - 1] = - Some((gate.clone(), gate_index, next_copy + 1)); - } - - (gate, gate_index, next_copy) - } - - /// Finds the last available U32 add-many gate with the given `num_addends` or adds one if there aren't any. - /// Returns `(g,i)` such that there is a `U32AddManyGate` for the given `num_addends` at index - /// `g` and the gate's `i`-th copy is available. - pub(crate) fn find_u32_add_many_gate(&mut self, num_addends: usize) -> (usize, usize) { - let (gate, i) = self - .batched_gates - .free_u32_add_many - .get(&num_addends) - .copied() - .unwrap_or_else(|| { - let gate = self.add_gate( - U32AddManyGate::new_from_config(&self.config, num_addends), - vec![], - ); - (gate, 0) - }); - - // Update `free_u32_add_many` with new values. - if i + 1 < U32AddManyGate::::new_from_config(&self.config, num_addends).num_ops { - self.batched_gates - .free_u32_add_many - .insert(num_addends, (gate, i + 1)); - } else { - self.batched_gates.free_u32_add_many.remove(&num_addends); - } - - (gate, i) - } - - pub(crate) fn find_u32_arithmetic_gate(&mut self) -> (usize, usize) { - let (gate_index, copy) = match self.batched_gates.current_u32_arithmetic_gate { - None => { - let gate = U32ArithmeticGate::new_from_config(&self.config); - let gate_index = self.add_gate(gate, vec![]); - (gate_index, 0) - } - Some((gate_index, copy)) => (gate_index, copy), - }; - - if copy == U32ArithmeticGate::::num_ops(&self.config) - 1 { - self.batched_gates.current_u32_arithmetic_gate = None; - } else { - self.batched_gates.current_u32_arithmetic_gate = Some((gate_index, copy + 1)); - } - - (gate_index, copy) - } - - pub(crate) fn find_u32_subtraction_gate(&mut self) -> (usize, usize) { - let (gate_index, copy) = match self.batched_gates.current_u32_subtraction_gate { - None => { - let gate = U32SubtractionGate::new_from_config(&self.config); - let gate_index = self.add_gate(gate, vec![]); - (gate_index, 0) - } - Some((gate_index, copy)) => (gate_index, copy), - }; - - if copy == U32SubtractionGate::::num_ops(&self.config) - 1 { - self.batched_gates.current_u32_subtraction_gate = None; - } else { - self.batched_gates.current_u32_subtraction_gate = Some((gate_index, copy + 1)); - } - - (gate_index, copy) - } - - /// Returns the gate index and copy index of a free `ConstantGate` slot, potentially adding a - /// new `ConstantGate` if needed. - fn constant_gate_instance(&mut self) -> (usize, usize) { - if self.batched_gates.free_constant.is_none() { - let num_consts = self.config.constant_gate_size; - // We will fill this `ConstantGate` with zero constants initially. - // These will be overwritten by `constant` as the gate instances are filled. - let gate = self.add_gate(ConstantGate { num_consts }, vec![F::ZERO; num_consts]); - self.batched_gates.free_constant = Some((gate, 0)); - } - - let (gate, instance) = self.batched_gates.free_constant.unwrap(); - if instance + 1 < self.config.constant_gate_size { - self.batched_gates.free_constant = Some((gate, instance + 1)); - } else { - self.batched_gates.free_constant = None; - } - (gate, instance) - } - - /// Fill the remaining unused arithmetic operations with zeros, so that all - /// `ArithmeticGate` are run. - fn fill_base_arithmetic_gates(&mut self) { - let zero = self.zero(); - for ((c0, c1), (_gate, i)) in self.batched_gates.free_base_arithmetic.clone() { - for _ in i..ArithmeticGate::num_ops(&self.config) { - // If we directly wire in zero, an optimization will skip doing anything and return - // zero. So we pass in a virtual target and connect it to zero afterward. - let dummy = self.add_virtual_target(); - self.arithmetic(c0, c1, dummy, dummy, dummy); - self.connect(dummy, zero); - } - } - assert!(self.batched_gates.free_base_arithmetic.is_empty()); - } - - /// Fill the remaining unused arithmetic operations with zeros, so that all - /// `ArithmeticExtensionGenerator`s are run. - fn fill_arithmetic_gates(&mut self) { - let zero = self.zero_extension(); - for ((c0, c1), (_gate, i)) in self.batched_gates.free_arithmetic.clone() { - for _ in i..ArithmeticExtensionGate::::num_ops(&self.config) { - // If we directly wire in zero, an optimization will skip doing anything and return - // zero. So we pass in a virtual target and connect it to zero afterward. - let dummy = self.add_virtual_extension_target(); - self.arithmetic_extension(c0, c1, dummy, dummy, dummy); - self.connect_extension(dummy, zero); - } - } - assert!(self.batched_gates.free_arithmetic.is_empty()); - } - - /// Fill the remaining unused arithmetic operations with zeros, so that all - /// `ArithmeticExtensionGenerator`s are run. - fn fill_mul_gates(&mut self) { - let zero = self.zero_extension(); - for (c0, (_gate, i)) in self.batched_gates.free_mul.clone() { - for _ in i..MulExtensionGate::::num_ops(&self.config) { - // If we directly wire in zero, an optimization will skip doing anything and return - // zero. So we pass in a virtual target and connect it to zero afterward. - let dummy = self.add_virtual_extension_target(); - self.arithmetic_extension(c0, F::ZERO, dummy, dummy, zero); - self.connect_extension(dummy, zero); - } - } - assert!(self.batched_gates.free_mul.is_empty()); - } - - /// Fill the remaining unused random access operations with zeros, so that all - /// `RandomAccessGenerator`s are run. - fn fill_random_access_gates(&mut self) { - let zero = self.zero(); - for (bits, (_, i)) in self.batched_gates.free_random_access.clone() { - let max_copies = - RandomAccessGate::::new_from_config(&self.config, bits).num_copies; - for _ in i..max_copies { - let result = self.random_access(zero, vec![zero; 1 << bits]); - self.connect(result, zero); - } - } - } - - /// Fill the remaining unused switch gates with dummy values, so that all - /// `SwitchGenerator`s are run. - fn fill_switch_gates(&mut self) { - let zero = self.zero(); - - for chunk_size in 1..=self.batched_gates.current_switch_gates.len() { - if let Some((gate, gate_index, mut copy)) = - self.batched_gates.current_switch_gates[chunk_size - 1].clone() - { - while copy < gate.num_copies { - for element in 0..chunk_size { - let wire_first_input = - Target::wire(gate_index, gate.wire_first_input(copy, element)); - let wire_second_input = - Target::wire(gate_index, gate.wire_second_input(copy, element)); - let wire_switch_bool = - Target::wire(gate_index, gate.wire_switch_bool(copy)); - self.connect(zero, wire_first_input); - self.connect(zero, wire_second_input); - self.connect(zero, wire_switch_bool); - } - copy += 1; - } - } - } - } - - /// Fill the remaining unused u32 add-many operations with zeros, so that all - /// `U32AddManyGenerator`s are run. - fn fill_u32_add_many_gates(&mut self) { - let zero = self.zero_u32(); - for (num_addends, (_, i)) in self.batched_gates.free_u32_add_many.clone() { - let max_copies = - U32AddManyGate::::new_from_config(&self.config, num_addends).num_ops; - for _ in i..max_copies { - let gate = U32AddManyGate::::new_from_config(&self.config, num_addends); - let (gate_index, copy) = self.find_u32_add_many_gate(num_addends); - - for j in 0..num_addends { - self.connect( - Target::wire(gate_index, gate.wire_ith_op_jth_addend(copy, j)), - zero.0, - ); - } - self.connect(Target::wire(gate_index, gate.wire_ith_carry(copy)), zero.0); - } - } - } - - /// Fill the remaining unused U32 arithmetic operations with zeros, so that all - /// `U32ArithmeticGenerator`s are run. - fn fill_u32_arithmetic_gates(&mut self) { - let zero = self.zero_u32(); - if let Some((_gate_index, copy)) = self.batched_gates.current_u32_arithmetic_gate { - for _ in copy..U32ArithmeticGate::::num_ops(&self.config) { - let dummy = self.add_virtual_u32_target(); - self.mul_add_u32(dummy, dummy, dummy); - self.connect_u32(dummy, zero); - } - } - } - - /// Fill the remaining unused U32 subtraction operations with zeros, so that all - /// `U32SubtractionGenerator`s are run. - fn fill_u32_subtraction_gates(&mut self) { - let zero = self.zero_u32(); - if let Some((_gate_index, copy)) = self.batched_gates.current_u32_subtraction_gate { - for _i in copy..U32SubtractionGate::::num_ops(&self.config) { - let dummy = self.add_virtual_u32_target(); - self.sub_u32(dummy, dummy, dummy); - self.connect_u32(dummy, zero); - } - } - } - - fn fill_batched_gates(&mut self) { - self.fill_arithmetic_gates(); - self.fill_base_arithmetic_gates(); - self.fill_mul_gates(); - self.fill_random_access_gates(); - self.fill_switch_gates(); - self.fill_u32_add_many_gates(); - self.fill_u32_arithmetic_gates(); - self.fill_u32_subtraction_gates(); - } -} ->>>>>>> aa48021 (windowed multiplication in circuit) From 772ff8d69ab7fb4e1d64dca10cec1a6739f0694e Mon Sep 17 00:00:00 2001 From: wborgeaud Date: Fri, 25 Feb 2022 16:30:01 +0100 Subject: [PATCH 35/56] Works --- plonky2/src/gadgets/curve_msm.rs | 155 +++++++++++++++++++++++++ plonky2/src/gadgets/ecdsa.rs | 5 +- plonky2/src/gadgets/mod.rs | 1 + plonky2/src/gadgets/split_nonnative.rs | 11 ++ 4 files changed, 170 insertions(+), 2 deletions(-) create mode 100644 plonky2/src/gadgets/curve_msm.rs diff --git a/plonky2/src/gadgets/curve_msm.rs b/plonky2/src/gadgets/curve_msm.rs new file mode 100644 index 00000000..d0add327 --- /dev/null +++ b/plonky2/src/gadgets/curve_msm.rs @@ -0,0 +1,155 @@ +use num::BigUint; +use plonky2_field::extension_field::Extendable; + +use crate::curve::curve_types::{Curve, CurveScalar}; +use crate::field::field_types::Field; +use crate::gadgets::curve::AffinePointTarget; +use crate::gadgets::nonnative::NonNativeTarget; +use crate::hash::hash_types::RichField; +use crate::hash::keccak::KeccakHash; +use crate::plonk::circuit_builder::CircuitBuilder; +use crate::plonk::config::{GenericHashOut, Hasher}; + +impl, const D: usize> CircuitBuilder { + /// Computes `n*p + m*q`. + pub fn curve_msm( + &mut self, + p: &AffinePointTarget, + q: &AffinePointTarget, + n: &NonNativeTarget, + m: &NonNativeTarget, + ) -> AffinePointTarget { + let bits_n = self.split_nonnative_to_bits(n); + let bits_m = self.split_nonnative_to_bits(m); + assert_eq!(bits_n.len(), bits_m.len()); + + let sum = self.curve_add(p, q); + let precomputation = vec![p.clone(), p.clone(), q.clone(), sum]; + + let two = self.two(); + let hash_0 = KeccakHash::<32>::hash_no_pad(&[F::ZERO]); + let hash_0_scalar = C::ScalarField::from_biguint(BigUint::from_bytes_le( + &GenericHashOut::::to_bytes(&hash_0), + )); + let starting_point = CurveScalar(hash_0_scalar) * C::GENERATOR_PROJECTIVE; + let starting_point_multiplied = + (0..C::ScalarField::BITS).fold(starting_point, |acc, _| acc.double()); + + let zero = self.zero(); + let mut result = self.constant_affine_point(starting_point.to_affine()); + for (b_n, b_m) in bits_n.into_iter().zip(bits_m).rev() { + result = self.curve_double(&result); + let index = self.mul_add(two, b_m.target, b_n.target); + let r = self.random_access_curve_points(index, precomputation.clone()); + let is_zero = self.is_equal(index, zero); + let should_add = self.not(is_zero); + result = self.curve_conditional_add(&result, &r, should_add); + } + let to_subtract = self.constant_affine_point(starting_point_multiplied.to_affine()); + let to_add = self.curve_neg(&to_subtract); + result = self.curve_add(&result, &to_add); + + result + } +} + +#[cfg(test)] +mod tests { + use std::ops::Neg; + + use anyhow::Result; + use plonky2_field::secp256k1_scalar::Secp256K1Scalar; + + use crate::curve::curve_types::{Curve, CurveScalar}; + use crate::curve::secp256k1::Secp256K1; + use crate::field::field_types::Field; + use crate::iop::witness::PartialWitness; + use crate::plonk::circuit_builder::CircuitBuilder; + use crate::plonk::circuit_data::CircuitConfig; + use crate::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; + use crate::plonk::verifier::verify; + + #[test] + fn test_yo() -> Result<()> { + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + + let config = CircuitConfig::standard_ecc_config(); + + let pw = PartialWitness::new(); + let mut builder = CircuitBuilder::::new(config); + + let p = + (CurveScalar(Secp256K1Scalar::rand()) * Secp256K1::GENERATOR_PROJECTIVE).to_affine(); + let q = + (CurveScalar(Secp256K1Scalar::rand()) * Secp256K1::GENERATOR_PROJECTIVE).to_affine(); + let n = Secp256K1Scalar::rand(); + let m = Secp256K1Scalar::rand(); + + let res = + (CurveScalar(n) * p.to_projective() + CurveScalar(m) * q.to_projective()).to_affine(); + let res_expected = builder.constant_affine_point(res); + builder.curve_assert_valid(&res_expected); + + let p_target = builder.constant_affine_point(p); + let q_target = builder.constant_affine_point(q); + let n_target = builder.constant_nonnative(n); + let m_target = builder.constant_nonnative(m); + + let res_target = builder.curve_msm(&p_target, &q_target, &n_target, &m_target); + builder.curve_assert_valid(&res_target); + + builder.connect_affine_point(&res_target, &res_expected); + + dbg!(builder.num_gates()); + let data = builder.build::(); + let proof = data.prove(pw).unwrap(); + + verify(proof, &data.verifier_only, &data.common) + } + + #[test] + fn test_ya() -> Result<()> { + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + + let config = CircuitConfig::standard_ecc_config(); + + let pw = PartialWitness::new(); + let mut builder = CircuitBuilder::::new(config); + + let p = + (CurveScalar(Secp256K1Scalar::rand()) * Secp256K1::GENERATOR_PROJECTIVE).to_affine(); + let q = + (CurveScalar(Secp256K1Scalar::rand()) * Secp256K1::GENERATOR_PROJECTIVE).to_affine(); + let n = Secp256K1Scalar::rand(); + let m = Secp256K1Scalar::rand(); + + let res = + (CurveScalar(n) * p.to_projective() + CurveScalar(m) * q.to_projective()).to_affine(); + let res_expected = builder.constant_affine_point(res); + builder.curve_assert_valid(&res_expected); + + let p_target = builder.constant_affine_point(p); + let q_target = builder.constant_affine_point(q); + let n_target = builder.constant_nonnative(n); + let m_target = builder.constant_nonnative(m); + + // let res0_target = builder.curve_scalar_mul_windowed(&p_target, &n_target); + // let res1_target = builder.curve_scalar_mul_windowed(&q_target, &m_target); + let res0_target = builder.curve_scalar_mul(&p_target, &n_target); + let res1_target = builder.curve_scalar_mul(&q_target, &m_target); + let res_target = builder.curve_add(&res0_target, &res1_target); + builder.curve_assert_valid(&res_target); + + builder.connect_affine_point(&res_target, &res_expected); + + dbg!(builder.num_gates()); + let data = builder.build::(); + let proof = data.prove(pw).unwrap(); + + verify(proof, &data.verifier_only, &data.common) + } +} diff --git a/plonky2/src/gadgets/ecdsa.rs b/plonky2/src/gadgets/ecdsa.rs index 0a95e189..5f4c4ff1 100644 --- a/plonky2/src/gadgets/ecdsa.rs +++ b/plonky2/src/gadgets/ecdsa.rs @@ -35,8 +35,8 @@ impl, const D: usize> CircuitBuilder { let u2 = self.mul_nonnative(&r, &c); let g = self.constant_affine_point(C::GENERATOR_AFFINE); - let point1 = self.curve_scalar_mul(&g, &u1); - let point2 = self.curve_scalar_mul(&pk.0, &u2); + let point1 = self.curve_scalar_mul_windowed(&g, &u1); + let point2 = self.curve_scalar_mul_windowed(&pk.0, &u2); let point = self.curve_add(&point1, &point2); let x = NonNativeTarget:: { @@ -97,6 +97,7 @@ mod tests { builder.verify_message(msg_target, sig_target, pk_target); + dbg!(builder.num_gates()); let data = builder.build::(); let proof = data.prove(pw).unwrap(); verify(proof, &data.verifier_only, &data.common) diff --git a/plonky2/src/gadgets/mod.rs b/plonky2/src/gadgets/mod.rs index d9c93db3..e35afeed 100644 --- a/plonky2/src/gadgets/mod.rs +++ b/plonky2/src/gadgets/mod.rs @@ -5,6 +5,7 @@ pub mod biguint; pub mod curve; pub mod curve_windowed_mul; // pub mod curve_msm; +pub mod curve_msm; pub mod ecdsa; pub mod glv; pub mod hash; diff --git a/plonky2/src/gadgets/split_nonnative.rs b/plonky2/src/gadgets/split_nonnative.rs index 70661506..18fc0264 100644 --- a/plonky2/src/gadgets/split_nonnative.rs +++ b/plonky2/src/gadgets/split_nonnative.rs @@ -35,6 +35,17 @@ impl, const D: usize> CircuitBuilder { .collect() } + pub fn split_nonnative_to_2_bit_limbs( + &mut self, + val: &NonNativeTarget, + ) -> Vec { + val.value + .limbs + .iter() + .flat_map(|&l| self.split_le_base::<4>(l.0, 16)) + .collect() + } + // Note: assumes its inputs are 4-bit limbs, and does not range-check. pub fn recombine_nonnative_4_bit_limbs( &mut self, From efb074b247bf6b7f86986cbe50fdb729f79f201d Mon Sep 17 00:00:00 2001 From: wborgeaud Date: Fri, 25 Feb 2022 17:21:35 +0100 Subject: [PATCH 36/56] Works with 2 --- plonky2/src/gadgets/curve_msm.rs | 52 +++++++++++++++++++++++--------- 1 file changed, 38 insertions(+), 14 deletions(-) diff --git a/plonky2/src/gadgets/curve_msm.rs b/plonky2/src/gadgets/curve_msm.rs index d0add327..e8db6fdd 100644 --- a/plonky2/src/gadgets/curve_msm.rs +++ b/plonky2/src/gadgets/curve_msm.rs @@ -1,3 +1,5 @@ +use std::collections::HashMap; + use num::BigUint; use plonky2_field::extension_field::Extendable; @@ -19,27 +21,51 @@ impl, const D: usize> CircuitBuilder { n: &NonNativeTarget, m: &NonNativeTarget, ) -> AffinePointTarget { - let bits_n = self.split_nonnative_to_bits(n); - let bits_m = self.split_nonnative_to_bits(m); - assert_eq!(bits_n.len(), bits_m.len()); + let limbs_n = self.split_nonnative_to_2_bit_limbs(n); + let limbs_m = self.split_nonnative_to_2_bit_limbs(m); + assert_eq!(limbs_n.len(), limbs_m.len()); - let sum = self.curve_add(p, q); - let precomputation = vec![p.clone(), p.clone(), q.clone(), sum]; - - let two = self.two(); let hash_0 = KeccakHash::<32>::hash_no_pad(&[F::ZERO]); let hash_0_scalar = C::ScalarField::from_biguint(BigUint::from_bytes_le( &GenericHashOut::::to_bytes(&hash_0), )); let starting_point = CurveScalar(hash_0_scalar) * C::GENERATOR_PROJECTIVE; + let starting_point_t = self.constant_affine_point(starting_point.to_affine()); + let neg = { + let mut neg = starting_point.to_affine(); + neg.y = -neg.y; + self.constant_affine_point(neg) + }; + + let mut precomputation = vec![p.clone(); 16]; + let mut cur_p = starting_point_t.clone(); + let mut cur_q = starting_point_t.clone(); + for i in 0..4 { + precomputation[i] = cur_p.clone(); + precomputation[4 * i] = cur_q.clone(); + cur_p = self.curve_add(&cur_p, p); + cur_q = self.curve_add(&cur_q, q); + } + for i in 1..4 { + precomputation[i] = self.curve_add(&precomputation[i], &neg); + precomputation[4 * i] = self.curve_add(&precomputation[4 * i], &neg); + } + for i in 1..4 { + for j in 1..4 { + precomputation[i + 4 * j] = + self.curve_add(&precomputation[i], &precomputation[4 * j]); + } + } + + let four = self.constant(F::from_canonical_usize(4)); let starting_point_multiplied = (0..C::ScalarField::BITS).fold(starting_point, |acc, _| acc.double()); let zero = self.zero(); let mut result = self.constant_affine_point(starting_point.to_affine()); - for (b_n, b_m) in bits_n.into_iter().zip(bits_m).rev() { - result = self.curve_double(&result); - let index = self.mul_add(two, b_m.target, b_n.target); + for (limb_n, limb_m) in limbs_n.into_iter().zip(limbs_m).rev() { + result = self.curve_repeated_double(&result, 2); + let index = self.mul_add(four, limb_m, limb_n); let r = self.random_access_curve_points(index, precomputation.clone()); let is_zero = self.is_equal(index, zero); let should_add = self.not(is_zero); @@ -137,10 +163,8 @@ mod tests { let n_target = builder.constant_nonnative(n); let m_target = builder.constant_nonnative(m); - // let res0_target = builder.curve_scalar_mul_windowed(&p_target, &n_target); - // let res1_target = builder.curve_scalar_mul_windowed(&q_target, &m_target); - let res0_target = builder.curve_scalar_mul(&p_target, &n_target); - let res1_target = builder.curve_scalar_mul(&q_target, &m_target); + let res0_target = builder.curve_scalar_mul_windowed(&p_target, &n_target); + let res1_target = builder.curve_scalar_mul_windowed(&q_target, &m_target); let res_target = builder.curve_add(&res0_target, &res1_target); builder.curve_assert_valid(&res_target); From 61af3a0de2dc85d336526b3014fa06d0d5765405 Mon Sep 17 00:00:00 2001 From: wborgeaud Date: Fri, 25 Feb 2022 19:39:30 +0100 Subject: [PATCH 37/56] Cleaning --- plonky2/src/gadgets/curve_msm.rs | 29 +++++++++++++---------------- 1 file changed, 13 insertions(+), 16 deletions(-) diff --git a/plonky2/src/gadgets/curve_msm.rs b/plonky2/src/gadgets/curve_msm.rs index e8db6fdd..12f15306 100644 --- a/plonky2/src/gadgets/curve_msm.rs +++ b/plonky2/src/gadgets/curve_msm.rs @@ -1,5 +1,3 @@ -use std::collections::HashMap; - use num::BigUint; use plonky2_field::extension_field::Extendable; @@ -29,17 +27,17 @@ impl, const D: usize> CircuitBuilder { let hash_0_scalar = C::ScalarField::from_biguint(BigUint::from_bytes_le( &GenericHashOut::::to_bytes(&hash_0), )); - let starting_point = CurveScalar(hash_0_scalar) * C::GENERATOR_PROJECTIVE; - let starting_point_t = self.constant_affine_point(starting_point.to_affine()); - let neg = { - let mut neg = starting_point.to_affine(); + let rando = (CurveScalar(hash_0_scalar) * C::GENERATOR_PROJECTIVE).to_affine(); + let rando_t = self.constant_affine_point(rando); + let neg_rando = { + let mut neg = rando; neg.y = -neg.y; self.constant_affine_point(neg) }; let mut precomputation = vec![p.clone(); 16]; - let mut cur_p = starting_point_t.clone(); - let mut cur_q = starting_point_t.clone(); + let mut cur_p = rando_t.clone(); + let mut cur_q = rando_t.clone(); for i in 0..4 { precomputation[i] = cur_p.clone(); precomputation[4 * i] = cur_q.clone(); @@ -47,8 +45,8 @@ impl, const D: usize> CircuitBuilder { cur_q = self.curve_add(&cur_q, q); } for i in 1..4 { - precomputation[i] = self.curve_add(&precomputation[i], &neg); - precomputation[4 * i] = self.curve_add(&precomputation[4 * i], &neg); + precomputation[i] = self.curve_add(&precomputation[i], &neg_rando); + precomputation[4 * i] = self.curve_add(&precomputation[4 * i], &neg_rando); } for i in 1..4 { for j in 1..4 { @@ -59,10 +57,10 @@ impl, const D: usize> CircuitBuilder { let four = self.constant(F::from_canonical_usize(4)); let starting_point_multiplied = - (0..C::ScalarField::BITS).fold(starting_point, |acc, _| acc.double()); + (0..C::ScalarField::BITS).fold(rando, |acc, _| acc.double()); let zero = self.zero(); - let mut result = self.constant_affine_point(starting_point.to_affine()); + let mut result = rando_t; for (limb_n, limb_m) in limbs_n.into_iter().zip(limbs_m).rev() { result = self.curve_repeated_double(&result, 2); let index = self.mul_add(four, limb_m, limb_n); @@ -71,7 +69,7 @@ impl, const D: usize> CircuitBuilder { let should_add = self.not(is_zero); result = self.curve_conditional_add(&result, &r, should_add); } - let to_subtract = self.constant_affine_point(starting_point_multiplied.to_affine()); + let to_subtract = self.constant_affine_point(starting_point_multiplied); let to_add = self.curve_neg(&to_subtract); result = self.curve_add(&result, &to_add); @@ -81,7 +79,6 @@ impl, const D: usize> CircuitBuilder { #[cfg(test)] mod tests { - use std::ops::Neg; use anyhow::Result; use plonky2_field::secp256k1_scalar::Secp256K1Scalar; @@ -96,7 +93,7 @@ mod tests { use crate::plonk::verifier::verify; #[test] - fn test_yo() -> Result<()> { + fn test_curve_msm() -> Result<()> { const D: usize = 2; type C = PoseidonGoldilocksConfig; type F = >::F; @@ -136,7 +133,7 @@ mod tests { } #[test] - fn test_ya() -> Result<()> { + fn test_naive_msm() -> Result<()> { const D: usize = 2; type C = PoseidonGoldilocksConfig; type F = >::F; From 74cf1d38870aa4b8343320f4872a74bc2e00acbd Mon Sep 17 00:00:00 2001 From: wborgeaud Date: Tue, 1 Mar 2022 07:59:35 +0100 Subject: [PATCH 38/56] Minor improvement --- plonky2/src/gadgets/curve_msm.rs | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/plonky2/src/gadgets/curve_msm.rs b/plonky2/src/gadgets/curve_msm.rs index 12f15306..43a22da2 100644 --- a/plonky2/src/gadgets/curve_msm.rs +++ b/plonky2/src/gadgets/curve_msm.rs @@ -29,11 +29,7 @@ impl, const D: usize> CircuitBuilder { )); let rando = (CurveScalar(hash_0_scalar) * C::GENERATOR_PROJECTIVE).to_affine(); let rando_t = self.constant_affine_point(rando); - let neg_rando = { - let mut neg = rando; - neg.y = -neg.y; - self.constant_affine_point(neg) - }; + let neg_rando = self.constant_affine_point(-rando); let mut precomputation = vec![p.clone(); 16]; let mut cur_p = rando_t.clone(); @@ -56,8 +52,6 @@ impl, const D: usize> CircuitBuilder { } let four = self.constant(F::from_canonical_usize(4)); - let starting_point_multiplied = - (0..C::ScalarField::BITS).fold(rando, |acc, _| acc.double()); let zero = self.zero(); let mut result = rando_t; @@ -69,8 +63,9 @@ impl, const D: usize> CircuitBuilder { let should_add = self.not(is_zero); result = self.curve_conditional_add(&result, &r, should_add); } - let to_subtract = self.constant_affine_point(starting_point_multiplied); - let to_add = self.curve_neg(&to_subtract); + let starting_point_multiplied = + (0..C::ScalarField::BITS).fold(rando, |acc, _| acc.double()); + let to_add = self.constant_affine_point(-starting_point_multiplied); result = self.curve_add(&result, &to_add); result From ba5b1f7278e27b54a361bd4f150b185eaf63930f Mon Sep 17 00:00:00 2001 From: wborgeaud Date: Wed, 2 Mar 2022 10:27:20 +0100 Subject: [PATCH 39/56] Fix `set_biguint_target` --- plonky2/src/iop/witness.rs | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/plonky2/src/iop/witness.rs b/plonky2/src/iop/witness.rs index de6d4a05..a013f811 100644 --- a/plonky2/src/iop/witness.rs +++ b/plonky2/src/iop/witness.rs @@ -1,4 +1,5 @@ use std::collections::HashMap; +use std::iter::repeat; use itertools::Itertools; use num::{BigUint, FromPrimitive, Zero}; @@ -160,7 +161,11 @@ pub trait Witness { } fn set_biguint_target(&mut self, target: &BigUintTarget, value: &BigUint) { - for (<, &l) in target.limbs.iter().zip(&value.to_u32_digits()) { + for (<, l) in target + .limbs + .iter() + .zip(value.to_u32_digits().into_iter().chain(repeat(0))) + { self.set_u32_target(lt, l); } } From 6f3ca6a0bc32d5aa9e68e32ccc89dbc7e2a0055c Mon Sep 17 00:00:00 2001 From: wborgeaud Date: Wed, 2 Mar 2022 11:04:05 +0100 Subject: [PATCH 40/56] Fixed base works --- plonky2/src/gadgets/curve_msm.rs | 149 +++++++++++++++++++++- plonky2/src/gadgets/curve_windowed_mul.rs | 15 ++- plonky2/src/gadgets/glv.rs | 64 +++++++++- plonky2/src/plonk/circuit_builder.rs | 2 +- 4 files changed, 216 insertions(+), 14 deletions(-) diff --git a/plonky2/src/gadgets/curve_msm.rs b/plonky2/src/gadgets/curve_msm.rs index 43a22da2..aa5a8d34 100644 --- a/plonky2/src/gadgets/curve_msm.rs +++ b/plonky2/src/gadgets/curve_msm.rs @@ -1,7 +1,8 @@ +use itertools::Itertools; use num::BigUint; use plonky2_field::extension_field::Extendable; -use crate::curve::curve_types::{Curve, CurveScalar}; +use crate::curve::curve_types::{AffinePoint, Curve, CurveScalar}; use crate::field::field_types::Field; use crate::gadgets::curve::AffinePointTarget; use crate::gadgets::nonnative::NonNativeTarget; @@ -70,18 +71,126 @@ impl, const D: usize> CircuitBuilder { result } + + // pub fn quad_curve_msm( + // &mut self, + // points: [AffinePointTarget; 4], + // scalars: [NonNativeTarget; 4], + // ) -> AffinePointTarget { + // let limbs = scalars + // .iter() + // .map(|s| self.split_nonnative_to_bits(n)) + // .collect_vec(); + // + // let hash_0 = KeccakHash::<32>::hash_no_pad(&[F::ZERO]); + // let hash_0_scalar = C::ScalarField::from_biguint(BigUint::from_bytes_le( + // &GenericHashOut::::to_bytes(&hash_0), + // )); + // let rando = (CurveScalar(hash_0_scalar) * C::GENERATOR_PROJECTIVE).to_affine(); + // let rando_t = self.constant_affine_point(rando); + // let neg_rando = self.constant_affine_point(-rando); + // + // let mut precomputation = vec![points[0].clone(); 16]; + // for i in 0..4 { + // precomputation[1 << i] = points[i].clone(); + // for j in 1..1 << (i - 1) {} + // } + // let mut cur_p = rando_t.clone(); + // let mut cur_q = rando_t.clone(); + // for i in 0..4 { + // precomputation[i] = cur_p.clone(); + // precomputation[4 * i] = cur_q.clone(); + // cur_p = self.curve_add(&cur_p, p); + // cur_q = self.curve_add(&cur_q, q); + // } + // for i in 1..4 { + // precomputation[i] = self.curve_add(&precomputation[i], &neg_rando); + // precomputation[4 * i] = self.curve_add(&precomputation[4 * i], &neg_rando); + // } + // for i in 1..4 { + // for j in 1..4 { + // precomputation[i + 4 * j] = + // self.curve_add(&precomputation[i], &precomputation[4 * j]); + // } + // } + // + // let four = self.constant(F::from_canonical_usize(4)); + // + // let zero = self.zero(); + // let mut result = rando_t; + // for (limb_n, limb_m) in limbs_n.into_iter().zip(limbs_m).rev() { + // result = self.curve_repeated_double(&result, 2); + // let index = self.mul_add(four, limb_m, limb_n); + // let r = self.random_access_curve_points(index, precomputation.clone()); + // let is_zero = self.is_equal(index, zero); + // let should_add = self.not(is_zero); + // result = self.curve_conditional_add(&result, &r, should_add); + // } + // let starting_point_multiplied = + // (0..C::ScalarField::BITS).fold(rando, |acc, _| acc.double()); + // let to_add = self.constant_affine_point(-starting_point_multiplied); + // result = self.curve_add(&result, &to_add); + // + // result + // } + + pub fn fixed_base_curve_mul( + &mut self, + base: &AffinePoint, + scalar: &NonNativeTarget, + ) -> AffinePointTarget { + let doubled_base = (0..scalar.value.limbs.len() * 8).scan(base.clone(), |acc, _| { + let tmp = acc.clone(); + for _ in 0..4 { + *acc = acc.double(); + } + Some(tmp) + }); + + let bits = self.split_nonnative_to_4_bit_limbs(scalar); + + // let rando = (CurveScalar(C::ScalarField::rand()) * C::GENERATOR_PROJECTIVE).to_affine(); + let hash_0 = KeccakHash::<32>::hash_no_pad(&[F::ZERO]); + let hash_0_scalar = C::ScalarField::from_biguint(BigUint::from_bytes_le( + &GenericHashOut::::to_bytes(&hash_0), + )); + let rando = (CurveScalar(hash_0_scalar) * C::GENERATOR_PROJECTIVE).to_affine(); + let zero = self.zero(); + let mut result = self.constant_affine_point(rando.clone()); + for (limb, point) in bits.into_iter().zip(doubled_base) { + let mul_point = (0..16) + .scan(AffinePoint::ZERO, |acc, _| { + let tmp = acc.clone(); + *acc = (point + *acc).to_affine(); + Some(tmp) + }) + .map(|p| self.constant_affine_point(p)) + .collect::>(); + let is_zero = self.is_equal(limb, zero); + let should_add = self.not(is_zero); + let r = self.random_access_curve_points(limb, mul_point); + result = self.curve_conditional_add(&result, &r, should_add); + } + + let to_add = self.constant_affine_point(-rando); + self.curve_add(&result, &to_add) + } } #[cfg(test)] mod tests { + use std::str::FromStr; use anyhow::Result; + use num::BigUint; + use plonky2_field::field_types::PrimeField; + use plonky2_field::secp256k1_base::Secp256K1Base; use plonky2_field::secp256k1_scalar::Secp256K1Scalar; - use crate::curve::curve_types::{Curve, CurveScalar}; + use crate::curve::curve_types::{AffinePoint, Curve, CurveScalar}; use crate::curve::secp256k1::Secp256K1; use crate::field::field_types::Field; - use crate::iop::witness::PartialWitness; + use crate::iop::witness::{PartialWitness, Witness}; use crate::plonk::circuit_builder::CircuitBuilder; use crate::plonk::circuit_data::CircuitConfig; use crate::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; @@ -168,4 +277,38 @@ mod tests { verify(proof, &data.verifier_only, &data.common) } + + #[test] + fn test_fixed_base() -> Result<()> { + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + + let config = CircuitConfig::standard_ecc_config(); + + let mut pw = PartialWitness::new(); + let mut builder = CircuitBuilder::::new(config); + + let g = Secp256K1::GENERATOR_AFFINE; + let n = Secp256K1Scalar::from_canonical_usize(10); + let n = Secp256K1Scalar::rand(); + + let res = (CurveScalar(n) * g.to_projective()).to_affine(); + let res_expected = builder.constant_affine_point(res); + builder.curve_assert_valid(&res_expected); + + let n_target = builder.add_virtual_nonnative_target::(); + pw.set_biguint_target(&n_target.value, &n.to_canonical_biguint()); + + let res_target = builder.fixed_base_curve_mul(&g, &n_target); + builder.curve_assert_valid(&res_target); + + builder.connect_affine_point(&res_target, &res_expected); + + dbg!(builder.num_gates()); + let data = builder.build::(); + let proof = data.prove(pw).unwrap(); + + verify(proof, &data.verifier_only, &data.common) + } } diff --git a/plonky2/src/gadgets/curve_windowed_mul.rs b/plonky2/src/gadgets/curve_windowed_mul.rs index 879e1ade..46c663a0 100644 --- a/plonky2/src/gadgets/curve_windowed_mul.rs +++ b/plonky2/src/gadgets/curve_windowed_mul.rs @@ -44,12 +44,21 @@ impl, const D: usize> CircuitBuilder { access_index: Target, v: Vec>, ) -> AffinePointTarget { - let num_limbs = v[0].x.value.num_limbs(); + let num_limbs = C::BaseField::BITS / 32; + let zero = self.zero_u32(); let x_limbs: Vec> = (0..num_limbs) - .map(|i| v.iter().map(|p| p.x.value.limbs[i].0).collect()) + .map(|i| { + v.iter() + .map(|p| p.x.value.limbs.get(i).unwrap_or(&zero).0) + .collect() + }) .collect(); let y_limbs: Vec> = (0..num_limbs) - .map(|i| v.iter().map(|p| p.y.value.limbs[i].0).collect()) + .map(|i| { + v.iter() + .map(|p| p.y.value.limbs.get(i).unwrap_or(&zero).0) + .collect() + }) .collect(); let selected_x_limbs: Vec<_> = x_limbs diff --git a/plonky2/src/gadgets/glv.rs b/plonky2/src/gadgets/glv.rs index 89013dd1..b9a3a380 100644 --- a/plonky2/src/gadgets/glv.rs +++ b/plonky2/src/gadgets/glv.rs @@ -19,6 +19,7 @@ impl, const D: usize> CircuitBuilder { self.constant_nonnative(BETA) } + // TODO: Add decomposition check. pub fn decompose_secp256k1_scalar( &mut self, k: &NonNativeTarget, @@ -59,12 +60,24 @@ impl, const D: usize> CircuitBuilder { y: p.y.clone(), }; - let part1 = self.curve_scalar_mul_windowed(p, &k1); - let part1_neg = self.curve_conditional_neg(&part1, k1_neg); - let part2 = self.curve_scalar_mul_windowed(&sp, &k2); - let part2_neg = self.curve_conditional_neg(&part2, k2_neg); - - self.curve_add(&part1_neg, &part2_neg) + // let part1 = self.curve_scalar_mul_windowed(p, &k1); + // let part1_neg = self.curve_conditional_neg(&part1, k1_neg); + // let part2 = self.curve_scalar_mul_windowed(&sp, &k2); + // let part2_neg = self.curve_conditional_neg(&part2, k2_neg); + // + // self.curve_add(&part1_neg, &part2_neg) + // dbg!(k1.value.limbs.len()); + // dbg!(k2.value.limbs.len()); + let p_neg = self.curve_conditional_neg(&p, k1_neg); + let sp_neg = self.curve_conditional_neg(&sp, k2_neg); + // let yo = self.curve_scalar_mul_windowed(&p_neg, &k1); + // let ya = self.curve_scalar_mul_windowed(&sp_neg, &k2); + // dbg!(&yo); + // dbg!(&ya); + // self.connect_affine_point(&part1_neg, &yo); + // self.connect_affine_point(&part2_neg, &ya); + self.curve_msm(&p_neg, &sp_neg, &k1, &k2) + // self.curve_add(&yo, &ya) } } @@ -105,7 +118,7 @@ mod tests { use crate::curve::curve_types::{Curve, CurveScalar}; use crate::curve::glv::glv_mul; use crate::curve::secp256k1::Secp256K1; - use crate::iop::witness::PartialWitness; + use crate::iop::witness::{PartialWitness, Witness}; use crate::plonk::circuit_builder::CircuitBuilder; use crate::plonk::circuit_data::CircuitConfig; use crate::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; @@ -134,6 +147,43 @@ mod tests { let actual = builder.glv_mul(&randot, &scalar_target); builder.connect_affine_point(&expected, &actual); + dbg!(builder.num_gates()); + let data = builder.build::(); + let proof = data.prove(pw).unwrap(); + + verify(proof, &data.verifier_only, &data.common) + } + + #[test] + fn test_wtf() -> Result<()> { + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + + let config = CircuitConfig::standard_ecc_config(); + + let mut pw = PartialWitness::new(); + let mut builder = CircuitBuilder::::new(config); + + let rando = + (CurveScalar(Secp256K1Scalar::rand()) * Secp256K1::GENERATOR_PROJECTIVE).to_affine(); + let randot = builder.constant_affine_point(rando); + + let scalar = Secp256K1Scalar::rand(); + let scalar_target = builder.constant_nonnative(scalar); + + let tr = builder.add_virtual_bool_target(); + pw.set_bool_target(tr, false); + + let randotneg = builder.curve_conditional_neg(&randot, tr); + let y = builder.curve_scalar_mul_windowed(&randotneg, &scalar_target); + + let yy = builder.curve_scalar_mul_windowed(&randot, &scalar_target); + let yy = builder.curve_conditional_neg(&yy, tr); + + builder.connect_affine_point(&y, &yy); + + dbg!(builder.num_gates()); let data = builder.build::(); let proof = data.prove(pw).unwrap(); diff --git a/plonky2/src/plonk/circuit_builder.rs b/plonky2/src/plonk/circuit_builder.rs index 63f45fec..e0d05c24 100644 --- a/plonky2/src/plonk/circuit_builder.rs +++ b/plonky2/src/plonk/circuit_builder.rs @@ -70,7 +70,7 @@ pub struct CircuitBuilder, const D: usize> { marked_targets: Vec>, /// Generators used to generate the witness. - generators: Vec>>, + pub generators: Vec>>, constants_to_targets: HashMap, targets_to_constants: HashMap, From 850df4dfb1413b9a00555789caf70fb3dd0ac461 Mon Sep 17 00:00:00 2001 From: wborgeaud Date: Wed, 2 Mar 2022 11:16:32 +0100 Subject: [PATCH 41/56] Add fixed base file --- plonky2/src/gadgets/curve_fixed_base.rs | 103 ++++++++++++++++ plonky2/src/gadgets/curve_msm.rs | 149 +----------------------- plonky2/src/gadgets/mod.rs | 4 +- 3 files changed, 108 insertions(+), 148 deletions(-) create mode 100644 plonky2/src/gadgets/curve_fixed_base.rs diff --git a/plonky2/src/gadgets/curve_fixed_base.rs b/plonky2/src/gadgets/curve_fixed_base.rs new file mode 100644 index 00000000..3b826f78 --- /dev/null +++ b/plonky2/src/gadgets/curve_fixed_base.rs @@ -0,0 +1,103 @@ +use num::BigUint; +use plonky2_field::extension_field::Extendable; + +use crate::curve::curve_types::{AffinePoint, Curve, CurveScalar}; +use crate::field::field_types::Field; +use crate::gadgets::curve::AffinePointTarget; +use crate::gadgets::nonnative::NonNativeTarget; +use crate::hash::hash_types::RichField; +use crate::hash::keccak::KeccakHash; +use crate::plonk::circuit_builder::CircuitBuilder; +use crate::plonk::config::{GenericHashOut, Hasher}; + +impl, const D: usize> CircuitBuilder { + pub fn fixed_base_curve_mul( + &mut self, + base: &AffinePoint, + scalar: &NonNativeTarget, + ) -> AffinePointTarget { + let doubled_base = (0..scalar.value.limbs.len() * 8).scan(base.clone(), |acc, _| { + let tmp = acc.clone(); + for _ in 0..4 { + *acc = acc.double(); + } + Some(tmp) + }); + + let bits = self.split_nonnative_to_4_bit_limbs(scalar); + + let hash_0 = KeccakHash::<32>::hash_no_pad(&[F::ZERO]); + let hash_0_scalar = C::ScalarField::from_biguint(BigUint::from_bytes_le( + &GenericHashOut::::to_bytes(&hash_0), + )); + let rando = (CurveScalar(hash_0_scalar) * C::GENERATOR_PROJECTIVE).to_affine(); + let zero = self.zero(); + let mut result = self.constant_affine_point(rando.clone()); + for (limb, point) in bits.into_iter().zip(doubled_base) { + let mul_point = (0..16) + .scan(AffinePoint::ZERO, |acc, _| { + let tmp = acc.clone(); + *acc = (point + *acc).to_affine(); + Some(tmp) + }) + .map(|p| self.constant_affine_point(p)) + .collect::>(); + let is_zero = self.is_equal(limb, zero); + let should_add = self.not(is_zero); + let r = self.random_access_curve_points(limb, mul_point); + result = self.curve_conditional_add(&result, &r, should_add); + } + + let to_add = self.constant_affine_point(-rando); + self.curve_add(&result, &to_add) + } +} + +#[cfg(test)] +mod tests { + use anyhow::Result; + use plonky2_field::field_types::PrimeField; + use plonky2_field::secp256k1_scalar::Secp256K1Scalar; + + use crate::curve::curve_types::{Curve, CurveScalar}; + use crate::curve::secp256k1::Secp256K1; + use crate::field::field_types::Field; + use crate::iop::witness::{PartialWitness, Witness}; + use crate::plonk::circuit_builder::CircuitBuilder; + use crate::plonk::circuit_data::CircuitConfig; + use crate::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; + use crate::plonk::verifier::verify; + + #[test] + fn test_fixed_base() -> Result<()> { + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + + let config = CircuitConfig::standard_ecc_config(); + + let mut pw = PartialWitness::new(); + let mut builder = CircuitBuilder::::new(config); + + let g = Secp256K1::GENERATOR_AFFINE; + let n = Secp256K1Scalar::rand(); + + let res = (CurveScalar(n) * g.to_projective()).to_affine(); + let res_expected = builder.constant_affine_point(res); + builder.curve_assert_valid(&res_expected); + + let n_target = builder.add_virtual_nonnative_target::(); + pw.set_biguint_target(&n_target.value, &n.to_canonical_biguint()); + + let res_target = builder.fixed_base_curve_mul(&g, &n_target); + builder.curve_assert_valid(&res_target); + + builder.connect_affine_point(&res_target, &res_expected); + + dbg!(builder.num_gates()); + let data = builder.build::(); + let proof = data.prove(pw).unwrap(); + + verify(proof, &data.verifier_only, &data.common) + } +} diff --git a/plonky2/src/gadgets/curve_msm.rs b/plonky2/src/gadgets/curve_msm.rs index aa5a8d34..43a22da2 100644 --- a/plonky2/src/gadgets/curve_msm.rs +++ b/plonky2/src/gadgets/curve_msm.rs @@ -1,8 +1,7 @@ -use itertools::Itertools; use num::BigUint; use plonky2_field::extension_field::Extendable; -use crate::curve::curve_types::{AffinePoint, Curve, CurveScalar}; +use crate::curve::curve_types::{Curve, CurveScalar}; use crate::field::field_types::Field; use crate::gadgets::curve::AffinePointTarget; use crate::gadgets::nonnative::NonNativeTarget; @@ -71,126 +70,18 @@ impl, const D: usize> CircuitBuilder { result } - - // pub fn quad_curve_msm( - // &mut self, - // points: [AffinePointTarget; 4], - // scalars: [NonNativeTarget; 4], - // ) -> AffinePointTarget { - // let limbs = scalars - // .iter() - // .map(|s| self.split_nonnative_to_bits(n)) - // .collect_vec(); - // - // let hash_0 = KeccakHash::<32>::hash_no_pad(&[F::ZERO]); - // let hash_0_scalar = C::ScalarField::from_biguint(BigUint::from_bytes_le( - // &GenericHashOut::::to_bytes(&hash_0), - // )); - // let rando = (CurveScalar(hash_0_scalar) * C::GENERATOR_PROJECTIVE).to_affine(); - // let rando_t = self.constant_affine_point(rando); - // let neg_rando = self.constant_affine_point(-rando); - // - // let mut precomputation = vec![points[0].clone(); 16]; - // for i in 0..4 { - // precomputation[1 << i] = points[i].clone(); - // for j in 1..1 << (i - 1) {} - // } - // let mut cur_p = rando_t.clone(); - // let mut cur_q = rando_t.clone(); - // for i in 0..4 { - // precomputation[i] = cur_p.clone(); - // precomputation[4 * i] = cur_q.clone(); - // cur_p = self.curve_add(&cur_p, p); - // cur_q = self.curve_add(&cur_q, q); - // } - // for i in 1..4 { - // precomputation[i] = self.curve_add(&precomputation[i], &neg_rando); - // precomputation[4 * i] = self.curve_add(&precomputation[4 * i], &neg_rando); - // } - // for i in 1..4 { - // for j in 1..4 { - // precomputation[i + 4 * j] = - // self.curve_add(&precomputation[i], &precomputation[4 * j]); - // } - // } - // - // let four = self.constant(F::from_canonical_usize(4)); - // - // let zero = self.zero(); - // let mut result = rando_t; - // for (limb_n, limb_m) in limbs_n.into_iter().zip(limbs_m).rev() { - // result = self.curve_repeated_double(&result, 2); - // let index = self.mul_add(four, limb_m, limb_n); - // let r = self.random_access_curve_points(index, precomputation.clone()); - // let is_zero = self.is_equal(index, zero); - // let should_add = self.not(is_zero); - // result = self.curve_conditional_add(&result, &r, should_add); - // } - // let starting_point_multiplied = - // (0..C::ScalarField::BITS).fold(rando, |acc, _| acc.double()); - // let to_add = self.constant_affine_point(-starting_point_multiplied); - // result = self.curve_add(&result, &to_add); - // - // result - // } - - pub fn fixed_base_curve_mul( - &mut self, - base: &AffinePoint, - scalar: &NonNativeTarget, - ) -> AffinePointTarget { - let doubled_base = (0..scalar.value.limbs.len() * 8).scan(base.clone(), |acc, _| { - let tmp = acc.clone(); - for _ in 0..4 { - *acc = acc.double(); - } - Some(tmp) - }); - - let bits = self.split_nonnative_to_4_bit_limbs(scalar); - - // let rando = (CurveScalar(C::ScalarField::rand()) * C::GENERATOR_PROJECTIVE).to_affine(); - let hash_0 = KeccakHash::<32>::hash_no_pad(&[F::ZERO]); - let hash_0_scalar = C::ScalarField::from_biguint(BigUint::from_bytes_le( - &GenericHashOut::::to_bytes(&hash_0), - )); - let rando = (CurveScalar(hash_0_scalar) * C::GENERATOR_PROJECTIVE).to_affine(); - let zero = self.zero(); - let mut result = self.constant_affine_point(rando.clone()); - for (limb, point) in bits.into_iter().zip(doubled_base) { - let mul_point = (0..16) - .scan(AffinePoint::ZERO, |acc, _| { - let tmp = acc.clone(); - *acc = (point + *acc).to_affine(); - Some(tmp) - }) - .map(|p| self.constant_affine_point(p)) - .collect::>(); - let is_zero = self.is_equal(limb, zero); - let should_add = self.not(is_zero); - let r = self.random_access_curve_points(limb, mul_point); - result = self.curve_conditional_add(&result, &r, should_add); - } - - let to_add = self.constant_affine_point(-rando); - self.curve_add(&result, &to_add) - } } #[cfg(test)] mod tests { - use std::str::FromStr; use anyhow::Result; - use num::BigUint; - use plonky2_field::field_types::PrimeField; - use plonky2_field::secp256k1_base::Secp256K1Base; use plonky2_field::secp256k1_scalar::Secp256K1Scalar; - use crate::curve::curve_types::{AffinePoint, Curve, CurveScalar}; + use crate::curve::curve_types::{Curve, CurveScalar}; use crate::curve::secp256k1::Secp256K1; use crate::field::field_types::Field; - use crate::iop::witness::{PartialWitness, Witness}; + use crate::iop::witness::PartialWitness; use crate::plonk::circuit_builder::CircuitBuilder; use crate::plonk::circuit_data::CircuitConfig; use crate::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; @@ -277,38 +168,4 @@ mod tests { verify(proof, &data.verifier_only, &data.common) } - - #[test] - fn test_fixed_base() -> Result<()> { - const D: usize = 2; - type C = PoseidonGoldilocksConfig; - type F = >::F; - - let config = CircuitConfig::standard_ecc_config(); - - let mut pw = PartialWitness::new(); - let mut builder = CircuitBuilder::::new(config); - - let g = Secp256K1::GENERATOR_AFFINE; - let n = Secp256K1Scalar::from_canonical_usize(10); - let n = Secp256K1Scalar::rand(); - - let res = (CurveScalar(n) * g.to_projective()).to_affine(); - let res_expected = builder.constant_affine_point(res); - builder.curve_assert_valid(&res_expected); - - let n_target = builder.add_virtual_nonnative_target::(); - pw.set_biguint_target(&n_target.value, &n.to_canonical_biguint()); - - let res_target = builder.fixed_base_curve_mul(&g, &n_target); - builder.curve_assert_valid(&res_target); - - builder.connect_affine_point(&res_target, &res_expected); - - dbg!(builder.num_gates()); - let data = builder.build::(); - let proof = data.prove(pw).unwrap(); - - verify(proof, &data.verifier_only, &data.common) - } } diff --git a/plonky2/src/gadgets/mod.rs b/plonky2/src/gadgets/mod.rs index e35afeed..50fc0437 100644 --- a/plonky2/src/gadgets/mod.rs +++ b/plonky2/src/gadgets/mod.rs @@ -3,9 +3,9 @@ pub mod arithmetic_extension; pub mod arithmetic_u32; pub mod biguint; pub mod curve; -pub mod curve_windowed_mul; -// pub mod curve_msm; +pub mod curve_fixed_base; pub mod curve_msm; +pub mod curve_windowed_mul; pub mod ecdsa; pub mod glv; pub mod hash; From 7c70c46ca7c787c287ed20e021c63a52f42f8b7b Mon Sep 17 00:00:00 2001 From: wborgeaud Date: Wed, 2 Mar 2022 13:19:31 +0100 Subject: [PATCH 42/56] Working GLV with MSM --- plonky2/src/gadgets/curve.rs | 5 +- plonky2/src/gadgets/curve_fixed_base.rs | 2 + plonky2/src/gadgets/curve_msm.rs | 89 ++++++++++++++++++++++--- plonky2/src/gadgets/glv.rs | 57 ++-------------- plonky2/src/gadgets/nonnative.rs | 2 +- plonky2/src/gadgets/split_nonnative.rs | 8 +-- plonky2/src/iop/generator.rs | 62 +++++++++++++++++ 7 files changed, 152 insertions(+), 73 deletions(-) diff --git a/plonky2/src/gadgets/curve.rs b/plonky2/src/gadgets/curve.rs index aa8c4112..a1fc3a8b 100644 --- a/plonky2/src/gadgets/curve.rs +++ b/plonky2/src/gadgets/curve.rs @@ -78,15 +78,12 @@ impl, const D: usize> CircuitBuilder { ) -> AffinePointTarget { let not_b = self.not(b); let neg = self.curve_neg(p); - let x_if_true = self.mul_nonnative_by_bool(&neg.x, b); let y_if_true = self.mul_nonnative_by_bool(&neg.y, b); - let x_if_false = self.mul_nonnative_by_bool(&p.x, not_b); let y_if_false = self.mul_nonnative_by_bool(&p.y, not_b); - let x = self.add_nonnative(&x_if_true, &x_if_false); let y = self.add_nonnative(&y_if_true, &y_if_false); - AffinePointTarget { x, y } + AffinePointTarget { x: p.x.clone(), y } } pub fn curve_double(&mut self, p: &AffinePointTarget) -> AffinePointTarget { diff --git a/plonky2/src/gadgets/curve_fixed_base.rs b/plonky2/src/gadgets/curve_fixed_base.rs index 3b826f78..70def6bc 100644 --- a/plonky2/src/gadgets/curve_fixed_base.rs +++ b/plonky2/src/gadgets/curve_fixed_base.rs @@ -11,6 +11,8 @@ use crate::plonk::circuit_builder::CircuitBuilder; use crate::plonk::config::{GenericHashOut, Hasher}; impl, const D: usize> CircuitBuilder { + /// Do windowed fixed-base scalar multiplication, using a 4-bit window. + // TODO: Benchmark other window sizes. pub fn fixed_base_curve_mul( &mut self, base: &AffinePoint, diff --git a/plonky2/src/gadgets/curve_msm.rs b/plonky2/src/gadgets/curve_msm.rs index 43a22da2..df13e8f3 100644 --- a/plonky2/src/gadgets/curve_msm.rs +++ b/plonky2/src/gadgets/curve_msm.rs @@ -3,8 +3,8 @@ use plonky2_field::extension_field::Extendable; use crate::curve::curve_types::{Curve, CurveScalar}; use crate::field::field_types::Field; +use crate::gadgets::biguint::BigUintTarget; use crate::gadgets::curve::AffinePointTarget; -use crate::gadgets::nonnative::NonNativeTarget; use crate::hash::hash_types::RichField; use crate::hash::keccak::KeccakHash; use crate::plonk::circuit_builder::CircuitBuilder; @@ -16,12 +16,13 @@ impl, const D: usize> CircuitBuilder { &mut self, p: &AffinePointTarget, q: &AffinePointTarget, - n: &NonNativeTarget, - m: &NonNativeTarget, + n: &BigUintTarget, + m: &BigUintTarget, ) -> AffinePointTarget { - let limbs_n = self.split_nonnative_to_2_bit_limbs(n); - let limbs_m = self.split_nonnative_to_2_bit_limbs(m); + let limbs_n = self.split_biguint_to_2_bit_limbs(n); + let limbs_m = self.split_biguint_to_2_bit_limbs(m); assert_eq!(limbs_n.len(), limbs_m.len()); + let num_limbs = limbs_n.len(); let hash_0 = KeccakHash::<32>::hash_no_pad(&[F::ZERO]); let hash_0_scalar = C::ScalarField::from_biguint(BigUint::from_bytes_le( @@ -63,8 +64,7 @@ impl, const D: usize> CircuitBuilder { let should_add = self.not(is_zero); result = self.curve_conditional_add(&result, &r, should_add); } - let starting_point_multiplied = - (0..C::ScalarField::BITS).fold(rando, |acc, _| acc.double()); + let starting_point_multiplied = (0..2 * num_limbs).fold(rando, |acc, _| acc.double()); let to_add = self.constant_affine_point(-starting_point_multiplied); result = self.curve_add(&result, &to_add); @@ -74,11 +74,14 @@ impl, const D: usize> CircuitBuilder { #[cfg(test)] mod tests { + use std::str::FromStr; use anyhow::Result; + use num::BigUint; + use plonky2_field::secp256k1_base::Secp256K1Base; use plonky2_field::secp256k1_scalar::Secp256K1Scalar; - use crate::curve::curve_types::{Curve, CurveScalar}; + use crate::curve::curve_types::{AffinePoint, Curve, CurveScalar}; use crate::curve::secp256k1::Secp256K1; use crate::field::field_types::Field; use crate::iop::witness::PartialWitness; @@ -115,7 +118,7 @@ mod tests { let n_target = builder.constant_nonnative(n); let m_target = builder.constant_nonnative(m); - let res_target = builder.curve_msm(&p_target, &q_target, &n_target, &m_target); + let res_target = builder.curve_msm(&p_target, &q_target, &n_target.value, &m_target.value); builder.curve_assert_valid(&res_target); builder.connect_affine_point(&res_target, &res_expected); @@ -168,4 +171,72 @@ mod tests { verify(proof, &data.verifier_only, &data.common) } + + #[test] + fn test_curve_lul() -> Result<()> { + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + + let config = CircuitConfig::standard_ecc_config(); + + let pw = PartialWitness::new(); + let mut builder = CircuitBuilder::::new(config); + + let p = AffinePoint:: { + x: Secp256K1Base::from_biguint( + BigUint::from_str( + "95702873347299649035220040874584348285675823985309557645567012532974768144045", + ) + .unwrap(), + ), + y: Secp256K1Base::from_biguint( + BigUint::from_str( + "34849299245821426255020320369755722155634282348110887335812955146294938249053", + ) + .unwrap(), + ), + zero: false, + }; + let q = AffinePoint:: { + x: Secp256K1Base::from_biguint( + BigUint::from_str( + "66037057977021147605301350925941983227524093291368248236634649161657340356645", + ) + .unwrap(), + ), + y: Secp256K1Base::from_biguint( + BigUint::from_str( + "80942789991494769168550664638932185697635702317529676703644628861613896422610", + ) + .unwrap(), + ), + zero: false, + }; + + let n = BigUint::from_str("89874493710619023150462632713212469930").unwrap(); + let m = BigUint::from_str("76073901947022186525975758425319149118").unwrap(); + + let res = (CurveScalar(Secp256K1Scalar::from_biguint(n.clone())) * p.to_projective() + + CurveScalar(Secp256K1Scalar::from_biguint(m.clone())) * q.to_projective()) + .to_affine(); + let res_expected = builder.constant_affine_point(res); + builder.curve_assert_valid(&res_expected); + + let p_target = builder.constant_affine_point(p); + let q_target = builder.constant_affine_point(q); + let n_target = builder.constant_biguint(&n); + let m_target = builder.constant_biguint(&m); + + let res_target = builder.curve_msm(&p_target, &q_target, &n_target, &m_target); + builder.curve_assert_valid(&res_target); + + builder.connect_affine_point(&res_target, &res_expected); + + dbg!(builder.num_gates()); + let data = builder.build::(); + let proof = data.prove(pw).unwrap(); + + verify(proof, &data.verifier_only, &data.common) + } } diff --git a/plonky2/src/gadgets/glv.rs b/plonky2/src/gadgets/glv.rs index b9a3a380..e0a4cfaa 100644 --- a/plonky2/src/gadgets/glv.rs +++ b/plonky2/src/gadgets/glv.rs @@ -43,6 +43,8 @@ impl, const D: usize> CircuitBuilder { _phantom: PhantomData, }); + // debug_assert!(k1_raw + S * k2_raw == k); + (k1, k2, k1_neg, k2_neg) } @@ -60,24 +62,9 @@ impl, const D: usize> CircuitBuilder { y: p.y.clone(), }; - // let part1 = self.curve_scalar_mul_windowed(p, &k1); - // let part1_neg = self.curve_conditional_neg(&part1, k1_neg); - // let part2 = self.curve_scalar_mul_windowed(&sp, &k2); - // let part2_neg = self.curve_conditional_neg(&part2, k2_neg); - // - // self.curve_add(&part1_neg, &part2_neg) - // dbg!(k1.value.limbs.len()); - // dbg!(k2.value.limbs.len()); let p_neg = self.curve_conditional_neg(&p, k1_neg); let sp_neg = self.curve_conditional_neg(&sp, k2_neg); - // let yo = self.curve_scalar_mul_windowed(&p_neg, &k1); - // let ya = self.curve_scalar_mul_windowed(&sp_neg, &k2); - // dbg!(&yo); - // dbg!(&ya); - // self.connect_affine_point(&part1_neg, &yo); - // self.connect_affine_point(&part2_neg, &ya); - self.curve_msm(&p_neg, &sp_neg, &k1, &k2) - // self.curve_add(&yo, &ya) + self.curve_msm(&p_neg, &sp_neg, &k1.value, &k2.value) } } @@ -118,7 +105,7 @@ mod tests { use crate::curve::curve_types::{Curve, CurveScalar}; use crate::curve::glv::glv_mul; use crate::curve::secp256k1::Secp256K1; - use crate::iop::witness::{PartialWitness, Witness}; + use crate::iop::witness::PartialWitness; use crate::plonk::circuit_builder::CircuitBuilder; use crate::plonk::circuit_data::CircuitConfig; use crate::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; @@ -153,40 +140,4 @@ mod tests { verify(proof, &data.verifier_only, &data.common) } - - #[test] - fn test_wtf() -> Result<()> { - const D: usize = 2; - type C = PoseidonGoldilocksConfig; - type F = >::F; - - let config = CircuitConfig::standard_ecc_config(); - - let mut pw = PartialWitness::new(); - let mut builder = CircuitBuilder::::new(config); - - let rando = - (CurveScalar(Secp256K1Scalar::rand()) * Secp256K1::GENERATOR_PROJECTIVE).to_affine(); - let randot = builder.constant_affine_point(rando); - - let scalar = Secp256K1Scalar::rand(); - let scalar_target = builder.constant_nonnative(scalar); - - let tr = builder.add_virtual_bool_target(); - pw.set_bool_target(tr, false); - - let randotneg = builder.curve_conditional_neg(&randot, tr); - let y = builder.curve_scalar_mul_windowed(&randotneg, &scalar_target); - - let yy = builder.curve_scalar_mul_windowed(&randot, &scalar_target); - let yy = builder.curve_conditional_neg(&yy, tr); - - builder.connect_affine_point(&y, &yy); - - dbg!(builder.num_gates()); - let data = builder.build::(); - let proof = data.prove(pw).unwrap(); - - verify(proof, &data.verifier_only, &data.common) - } } diff --git a/plonky2/src/gadgets/nonnative.rs b/plonky2/src/gadgets/nonnative.rs index 910915d0..73bc0ad3 100644 --- a/plonky2/src/gadgets/nonnative.rs +++ b/plonky2/src/gadgets/nonnative.rs @@ -454,7 +454,7 @@ impl, const D: usize, FF: PrimeField> SimpleGenerat let b_biguint = b.to_canonical_biguint(); let modulus = FF::order(); - let (diff_biguint, overflow) = if a_biguint > b_biguint { + let (diff_biguint, overflow) = if a_biguint >= b_biguint { (a_biguint - b_biguint, false) } else { (modulus + a_biguint - b_biguint, true) diff --git a/plonky2/src/gadgets/split_nonnative.rs b/plonky2/src/gadgets/split_nonnative.rs index 18fc0264..becf1177 100644 --- a/plonky2/src/gadgets/split_nonnative.rs +++ b/plonky2/src/gadgets/split_nonnative.rs @@ -35,12 +35,8 @@ impl, const D: usize> CircuitBuilder { .collect() } - pub fn split_nonnative_to_2_bit_limbs( - &mut self, - val: &NonNativeTarget, - ) -> Vec { - val.value - .limbs + pub fn split_biguint_to_2_bit_limbs(&mut self, val: &BigUintTarget) -> Vec { + val.limbs .iter() .flat_map(|&l| self.split_le_base::<4>(l.0, 16)) .collect() diff --git a/plonky2/src/iop/generator.rs b/plonky2/src/iop/generator.rs index 1569e889..4dcd11da 100644 --- a/plonky2/src/iop/generator.rs +++ b/plonky2/src/iop/generator.rs @@ -89,6 +89,68 @@ pub(crate) fn generate_partial_witness< } pending_generator_indices = next_pending_generator_indices; + // for t in [ + // Target::VirtualTarget { index: 57934 }, + // Target::VirtualTarget { index: 57935 }, + // Target::VirtualTarget { index: 57936 }, + // Target::VirtualTarget { index: 57937 }, + // Target::VirtualTarget { index: 57938 }, + // Target::VirtualTarget { index: 57939 }, + // Target::VirtualTarget { index: 57940 }, + // Target::VirtualTarget { index: 57941 }, + // ] { + // if let Some(v) = witness.try_get_target(t) { + // println!("a {}", v); + // } + // } + // for t in [ + // Target::VirtualTarget { index: 57952 }, + // Target::VirtualTarget { index: 57953 }, + // Target::VirtualTarget { index: 57954 }, + // Target::VirtualTarget { index: 57955 }, + // Target::VirtualTarget { index: 57956 }, + // Target::VirtualTarget { index: 57957 }, + // Target::VirtualTarget { index: 57958 }, + // Target::VirtualTarget { index: 57959 }, + // ] { + // if let Some(v) = witness.try_get_target(t) { + // println!("b {}", v); + // } + // } + // + // let t = Target::Wire(Wire { + // gate: 141_857, + // input: 8, + // }); + // if let Some(v) = witness.try_get_target(t) { + // println!("prod_exp {}", v); + // } + // let t = Target::Wire(Wire { + // gate: 141_863, + // input: 22, + // }); + // if let Some(v) = witness.try_get_target(t) { + // println!("prod act {}", v); + // } + // let t = Target::Wire(Wire { gate: 9, input: 3 }); + // if let Some(v) = witness.try_get_target(t) { + // println!("modulus {}", v); + // } + // let t = Target::VirtualTarget { index: 57_976 }; + // if let Some(v) = witness.try_get_target(t) { + // println!("overflow {}", v); + // } + // let t = Target::Wire(Wire { + // gate: 141_885, + // input: 8, + // }); + // if let Some(v) = witness.try_get_target(t) { + // println!("mod time ov {}", v); + // } + // let t = Target::VirtualTarget { index: 57_968 }; + // if let Some(v) = witness.try_get_target(t) { + // println!("prod {}", v); + // } } assert_eq!( From 2571862f00eb997c4bacc2118b731ce18a05e996 Mon Sep 17 00:00:00 2001 From: wborgeaud Date: Wed, 2 Mar 2022 13:31:16 +0100 Subject: [PATCH 43/56] Working GLV decomposition check --- plonky2/src/curve/glv.rs | 12 ++++++------ plonky2/src/gadgets/curve.rs | 12 ++++-------- plonky2/src/gadgets/glv.rs | 12 +++++++++--- plonky2/src/gadgets/nonnative.rs | 13 +++++++++++++ 4 files changed, 32 insertions(+), 17 deletions(-) diff --git a/plonky2/src/curve/glv.rs b/plonky2/src/curve/glv.rs index 11172be0..aeeb463e 100644 --- a/plonky2/src/curve/glv.rs +++ b/plonky2/src/curve/glv.rs @@ -8,14 +8,14 @@ use crate::curve::curve_msm::msm_parallel; use crate::curve::curve_types::{AffinePoint, ProjectivePoint}; use crate::curve::secp256k1::Secp256K1; -pub const BETA: Secp256K1Base = Secp256K1Base([ +pub const GLV_BETA: Secp256K1Base = Secp256K1Base([ 13923278643952681454, 11308619431505398165, 7954561588662645993, 8856726876819556112, ]); -const S: Secp256K1Scalar = Secp256K1Scalar([ +pub const GLV_S: Secp256K1Scalar = Secp256K1Scalar([ 16069571880186789234, 1310022930574435960, 11900229862571533402, @@ -52,7 +52,7 @@ pub fn decompose_secp256k1_scalar( let k1_raw = k - c1 * A1 - c2 * A2; let k2_raw = c1 * MINUS_B1 - c2 * B2; - debug_assert!(k1_raw + S * k2_raw == k); + debug_assert!(k1_raw + GLV_S * k2_raw == k); let two = BigUint::from_slice(&[2]); let k1_neg = k1_raw.to_canonical_biguint() > p.clone() / two.clone(); @@ -80,7 +80,7 @@ pub fn glv_mul(p: ProjectivePoint, k: Secp256K1Scalar) -> ProjectiveP let p_affine = p.to_affine(); let sp = AffinePoint:: { - x: p_affine.x * BETA, + x: p_affine.x * GLV_BETA, y: p_affine.y, zero: p_affine.zero, }; @@ -102,7 +102,7 @@ mod tests { use plonky2_field::secp256k1_scalar::Secp256K1Scalar; use crate::curve::curve_types::{Curve, CurveScalar}; - use crate::curve::glv::{decompose_secp256k1_scalar, glv_mul, S}; + use crate::curve::glv::{decompose_secp256k1_scalar, glv_mul, GLV_S}; use crate::curve::secp256k1::Secp256K1; #[test] @@ -113,7 +113,7 @@ mod tests { let m1 = if k1_neg { -one } else { one }; let m2 = if k2_neg { -one } else { one }; - assert!(k1 * m1 + S * k2 * m2 == k); + assert!(k1 * m1 + GLV_S * k2 * m2 == k); Ok(()) } diff --git a/plonky2/src/gadgets/curve.rs b/plonky2/src/gadgets/curve.rs index a1fc3a8b..e4e66a4e 100644 --- a/plonky2/src/gadgets/curve.rs +++ b/plonky2/src/gadgets/curve.rs @@ -76,14 +76,10 @@ impl, const D: usize> CircuitBuilder { p: &AffinePointTarget, b: BoolTarget, ) -> AffinePointTarget { - let not_b = self.not(b); - let neg = self.curve_neg(p); - let y_if_true = self.mul_nonnative_by_bool(&neg.y, b); - let y_if_false = self.mul_nonnative_by_bool(&p.y, not_b); - - let y = self.add_nonnative(&y_if_true, &y_if_false); - - AffinePointTarget { x: p.x.clone(), y } + AffinePointTarget { + x: p.x.clone(), + y: self.nonnative_conditional_neg(&p.y, b), + } } pub fn curve_double(&mut self, p: &AffinePointTarget) -> AffinePointTarget { diff --git a/plonky2/src/gadgets/glv.rs b/plonky2/src/gadgets/glv.rs index e0a4cfaa..8447137d 100644 --- a/plonky2/src/gadgets/glv.rs +++ b/plonky2/src/gadgets/glv.rs @@ -4,7 +4,7 @@ use plonky2_field::extension_field::Extendable; use plonky2_field::secp256k1_base::Secp256K1Base; use plonky2_field::secp256k1_scalar::Secp256K1Scalar; -use crate::curve::glv::{decompose_secp256k1_scalar, BETA}; +use crate::curve::glv::{decompose_secp256k1_scalar, GLV_BETA, GLV_S}; use crate::curve::secp256k1::Secp256K1; use crate::gadgets::curve::AffinePointTarget; use crate::gadgets::nonnative::NonNativeTarget; @@ -16,7 +16,7 @@ use crate::plonk::circuit_builder::CircuitBuilder; impl, const D: usize> CircuitBuilder { pub fn secp256k1_glv_beta(&mut self) -> NonNativeTarget { - self.constant_nonnative(BETA) + self.constant_nonnative(GLV_BETA) } // TODO: Add decomposition check. @@ -43,7 +43,13 @@ impl, const D: usize> CircuitBuilder { _phantom: PhantomData, }); - // debug_assert!(k1_raw + S * k2_raw == k); + // Check that `k1_raw + GLV_S * k2_raw == k`. + let k1_raw = self.nonnative_conditional_neg(&k1, k1_neg); + let k2_raw = self.nonnative_conditional_neg(&k2, k2_neg); + let s = self.constant_nonnative(GLV_S); + let mut should_be_k = self.mul_nonnative(&s, &k2_raw); + should_be_k = self.add_nonnative(&should_be_k, &k1_raw); + self.connect_nonnative(&should_be_k, &k); (k1, k2, k1_neg, k2_neg) } diff --git a/plonky2/src/gadgets/nonnative.rs b/plonky2/src/gadgets/nonnative.rs index 73bc0ad3..6c483a86 100644 --- a/plonky2/src/gadgets/nonnative.rs +++ b/plonky2/src/gadgets/nonnative.rs @@ -338,6 +338,19 @@ impl, const D: usize> CircuitBuilder { result } + + pub fn nonnative_conditional_neg( + &mut self, + x: &NonNativeTarget, + b: BoolTarget, + ) -> NonNativeTarget { + let not_b = self.not(b); + let neg = self.neg_nonnative(x); + let x_if_true = self.mul_nonnative_by_bool(&neg, b); + let x_if_false = self.mul_nonnative_by_bool(x, not_b); + + self.add_nonnative(&x_if_true, &x_if_false) + } } #[derive(Debug)] From c8d3335bce88da27995045644ab520d4ddcf372f Mon Sep 17 00:00:00 2001 From: wborgeaud Date: Wed, 2 Mar 2022 13:37:01 +0100 Subject: [PATCH 44/56] ECDSA verification in 101k gates --- plonky2/src/gadgets/ecdsa.rs | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/plonky2/src/gadgets/ecdsa.rs b/plonky2/src/gadgets/ecdsa.rs index 5f4c4ff1..5bf43b04 100644 --- a/plonky2/src/gadgets/ecdsa.rs +++ b/plonky2/src/gadgets/ecdsa.rs @@ -1,6 +1,9 @@ use std::marker::PhantomData; +use plonky2_field::secp256k1_scalar::Secp256K1Scalar; + use crate::curve::curve_types::Curve; +use crate::curve::secp256k1::Secp256K1; use crate::field::extension_field::Extendable; use crate::gadgets::curve::AffinePointTarget; use crate::gadgets::nonnative::NonNativeTarget; @@ -20,11 +23,11 @@ pub struct ECDSASignatureTarget { } impl, const D: usize> CircuitBuilder { - pub fn verify_message( + pub fn verify_message( &mut self, - msg: NonNativeTarget, - sig: ECDSASignatureTarget, - pk: ECDSAPublicKeyTarget, + msg: NonNativeTarget, + sig: ECDSASignatureTarget, + pk: ECDSAPublicKeyTarget, ) { let ECDSASignatureTarget { r, s } = sig; @@ -34,12 +37,11 @@ impl, const D: usize> CircuitBuilder { let u1 = self.mul_nonnative(&msg, &c); let u2 = self.mul_nonnative(&r, &c); - let g = self.constant_affine_point(C::GENERATOR_AFFINE); - let point1 = self.curve_scalar_mul_windowed(&g, &u1); - let point2 = self.curve_scalar_mul_windowed(&pk.0, &u2); + let point1 = self.fixed_base_curve_mul(&Secp256K1::GENERATOR_AFFINE, &u1); + let point2 = self.glv_mul(&pk.0, &u2); let point = self.curve_add(&point1, &point2); - let x = NonNativeTarget:: { + let x = NonNativeTarget:: { value: point.x.value, _phantom: PhantomData, }; From f6525ed11afd252c0d4faec9346323db2d519a2a Mon Sep 17 00:00:00 2001 From: wborgeaud Date: Thu, 3 Mar 2022 04:15:15 +0100 Subject: [PATCH 45/56] Add wide config for ECDSA in < 2^16 gates --- plonky2/src/gadgets/ecdsa.rs | 10 +++++++++- plonky2/src/plonk/circuit_data.rs | 7 +++++++ 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/plonky2/src/gadgets/ecdsa.rs b/plonky2/src/gadgets/ecdsa.rs index 5bf43b04..1f4012a6 100644 --- a/plonky2/src/gadgets/ecdsa.rs +++ b/plonky2/src/gadgets/ecdsa.rs @@ -74,7 +74,15 @@ mod tests { type Curve = Secp256K1; - let config = CircuitConfig::standard_ecc_config(); + const WIDE: bool = true; + + let config = if WIDE { + // < 2^16 gates. + CircuitConfig::wide_ecc_config() + } else { + // < 2^17 gates. + CircuitConfig::standard_ecc_config() + }; let pw = PartialWitness::new(); let mut builder = CircuitBuilder::::new(config); diff --git a/plonky2/src/plonk/circuit_data.rs b/plonky2/src/plonk/circuit_data.rs index 3d4ee2df..34b38fcf 100644 --- a/plonky2/src/plonk/circuit_data.rs +++ b/plonky2/src/plonk/circuit_data.rs @@ -86,6 +86,13 @@ impl CircuitConfig { } } + pub fn wide_ecc_config() -> Self { + Self { + num_wires: 234, + ..Self::standard_recursion_config() + } + } + pub fn standard_recursion_zk_config() -> Self { CircuitConfig { zero_knowledge: true, From 90df0d9d3accce4e063affe1fb21336cca2fb33b Mon Sep 17 00:00:00 2001 From: wborgeaud Date: Thu, 3 Mar 2022 04:19:17 +0100 Subject: [PATCH 46/56] Clippy --- plonky2/src/gadgets/curve_fixed_base.rs | 8 ++++---- plonky2/src/gadgets/glv.rs | 4 ++-- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/plonky2/src/gadgets/curve_fixed_base.rs b/plonky2/src/gadgets/curve_fixed_base.rs index 70def6bc..4ba8d11e 100644 --- a/plonky2/src/gadgets/curve_fixed_base.rs +++ b/plonky2/src/gadgets/curve_fixed_base.rs @@ -18,8 +18,8 @@ impl, const D: usize> CircuitBuilder { base: &AffinePoint, scalar: &NonNativeTarget, ) -> AffinePointTarget { - let doubled_base = (0..scalar.value.limbs.len() * 8).scan(base.clone(), |acc, _| { - let tmp = acc.clone(); + let doubled_base = (0..scalar.value.limbs.len() * 8).scan(*base, |acc, _| { + let tmp = *acc; for _ in 0..4 { *acc = acc.double(); } @@ -34,11 +34,11 @@ impl, const D: usize> CircuitBuilder { )); let rando = (CurveScalar(hash_0_scalar) * C::GENERATOR_PROJECTIVE).to_affine(); let zero = self.zero(); - let mut result = self.constant_affine_point(rando.clone()); + let mut result = self.constant_affine_point(rando); for (limb, point) in bits.into_iter().zip(doubled_base) { let mul_point = (0..16) .scan(AffinePoint::ZERO, |acc, _| { - let tmp = acc.clone(); + let tmp = *acc; *acc = (point + *acc).to_affine(); Some(tmp) }) diff --git a/plonky2/src/gadgets/glv.rs b/plonky2/src/gadgets/glv.rs index 8447137d..f0c4704b 100644 --- a/plonky2/src/gadgets/glv.rs +++ b/plonky2/src/gadgets/glv.rs @@ -49,7 +49,7 @@ impl, const D: usize> CircuitBuilder { let s = self.constant_nonnative(GLV_S); let mut should_be_k = self.mul_nonnative(&s, &k2_raw); should_be_k = self.add_nonnative(&should_be_k, &k1_raw); - self.connect_nonnative(&should_be_k, &k); + self.connect_nonnative(&should_be_k, k); (k1, k2, k1_neg, k2_neg) } @@ -68,7 +68,7 @@ impl, const D: usize> CircuitBuilder { y: p.y.clone(), }; - let p_neg = self.curve_conditional_neg(&p, k1_neg); + let p_neg = self.curve_conditional_neg(p, k1_neg); let sp_neg = self.curve_conditional_neg(&sp, k2_neg); self.curve_msm(&p_neg, &sp_neg, &k1.value, &k2.value) } From 47523c086a783c202d9e2da6b5ada7a8b853dbbd Mon Sep 17 00:00:00 2001 From: wborgeaud Date: Thu, 3 Mar 2022 04:43:04 +0100 Subject: [PATCH 47/56] Minor --- plonky2/src/gadgets/curve_fixed_base.rs | 6 +++--- plonky2/src/gadgets/ecdsa.rs | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/plonky2/src/gadgets/curve_fixed_base.rs b/plonky2/src/gadgets/curve_fixed_base.rs index 4ba8d11e..0cbd7a9e 100644 --- a/plonky2/src/gadgets/curve_fixed_base.rs +++ b/plonky2/src/gadgets/curve_fixed_base.rs @@ -15,10 +15,10 @@ impl, const D: usize> CircuitBuilder { // TODO: Benchmark other window sizes. pub fn fixed_base_curve_mul( &mut self, - base: &AffinePoint, + base: AffinePoint, scalar: &NonNativeTarget, ) -> AffinePointTarget { - let doubled_base = (0..scalar.value.limbs.len() * 8).scan(*base, |acc, _| { + let doubled_base = (0..scalar.value.limbs.len() * 8).scan(base, |acc, _| { let tmp = *acc; for _ in 0..4 { *acc = acc.double(); @@ -91,7 +91,7 @@ mod tests { let n_target = builder.add_virtual_nonnative_target::(); pw.set_biguint_target(&n_target.value, &n.to_canonical_biguint()); - let res_target = builder.fixed_base_curve_mul(&g, &n_target); + let res_target = builder.fixed_base_curve_mul(g, &n_target); builder.curve_assert_valid(&res_target); builder.connect_affine_point(&res_target, &res_expected); diff --git a/plonky2/src/gadgets/ecdsa.rs b/plonky2/src/gadgets/ecdsa.rs index 1f4012a6..672700c8 100644 --- a/plonky2/src/gadgets/ecdsa.rs +++ b/plonky2/src/gadgets/ecdsa.rs @@ -37,7 +37,7 @@ impl, const D: usize> CircuitBuilder { let u1 = self.mul_nonnative(&msg, &c); let u2 = self.mul_nonnative(&r, &c); - let point1 = self.fixed_base_curve_mul(&Secp256K1::GENERATOR_AFFINE, &u1); + let point1 = self.fixed_base_curve_mul(Secp256K1::GENERATOR_AFFINE, &u1); let point2 = self.glv_mul(&pk.0, &u2); let point = self.curve_add(&point1, &point2); From 18e341ff18becda2f6fe942bce664ccf1264d37b Mon Sep 17 00:00:00 2001 From: wborgeaud Date: Thu, 3 Mar 2022 08:06:21 +0100 Subject: [PATCH 48/56] Comments --- plonky2/src/gadgets/curve_fixed_base.rs | 16 ++++++++++------ plonky2/src/gadgets/curve_msm.rs | 6 +++++- plonky2/src/gadgets/glv.rs | 1 - 3 files changed, 15 insertions(+), 8 deletions(-) diff --git a/plonky2/src/gadgets/curve_fixed_base.rs b/plonky2/src/gadgets/curve_fixed_base.rs index 0cbd7a9e..3c470044 100644 --- a/plonky2/src/gadgets/curve_fixed_base.rs +++ b/plonky2/src/gadgets/curve_fixed_base.rs @@ -12,13 +12,13 @@ use crate::plonk::config::{GenericHashOut, Hasher}; impl, const D: usize> CircuitBuilder { /// Do windowed fixed-base scalar multiplication, using a 4-bit window. - // TODO: Benchmark other window sizes. pub fn fixed_base_curve_mul( &mut self, base: AffinePoint, scalar: &NonNativeTarget, ) -> AffinePointTarget { - let doubled_base = (0..scalar.value.limbs.len() * 8).scan(base, |acc, _| { + // Holds `(16^i) * base` for `i=0..scalar.value.limbs.len() * 8`. + let scaled_base = (0..scalar.value.limbs.len() * 8).scan(base, |acc, _| { let tmp = *acc; for _ in 0..4 { *acc = acc.double(); @@ -26,17 +26,20 @@ impl, const D: usize> CircuitBuilder { Some(tmp) }); - let bits = self.split_nonnative_to_4_bit_limbs(scalar); + let limbs = self.split_nonnative_to_4_bit_limbs(scalar); let hash_0 = KeccakHash::<32>::hash_no_pad(&[F::ZERO]); let hash_0_scalar = C::ScalarField::from_biguint(BigUint::from_bytes_le( &GenericHashOut::::to_bytes(&hash_0), )); let rando = (CurveScalar(hash_0_scalar) * C::GENERATOR_PROJECTIVE).to_affine(); + let zero = self.zero(); let mut result = self.constant_affine_point(rando); - for (limb, point) in bits.into_iter().zip(doubled_base) { - let mul_point = (0..16) + // `s * P = sum s_i * P_i` with `P_i = (16^i) * P` and `s = sum s_i * (16^i)`. + for (limb, point) in limbs.into_iter().zip(scaled_base) { + // Holds `t * P_i` for `p=0..16`. + let muls_point = (0..16) .scan(AffinePoint::ZERO, |acc, _| { let tmp = *acc; *acc = (point + *acc).to_affine(); @@ -46,7 +49,8 @@ impl, const D: usize> CircuitBuilder { .collect::>(); let is_zero = self.is_equal(limb, zero); let should_add = self.not(is_zero); - let r = self.random_access_curve_points(limb, mul_point); + // `r = s_i * P_i` + let r = self.random_access_curve_points(limb, muls_point); result = self.curve_conditional_add(&result, &r, should_add); } diff --git a/plonky2/src/gadgets/curve_msm.rs b/plonky2/src/gadgets/curve_msm.rs index df13e8f3..5d505c4d 100644 --- a/plonky2/src/gadgets/curve_msm.rs +++ b/plonky2/src/gadgets/curve_msm.rs @@ -11,7 +11,10 @@ use crate::plonk::circuit_builder::CircuitBuilder; use crate::plonk::config::{GenericHashOut, Hasher}; impl, const D: usize> CircuitBuilder { - /// Computes `n*p + m*q`. + /// Computes `n*p + m*q` using windowed MSM, with a 2-bit window. + /// See Algorithm 9.23 in Handbook of Elliptic and Hyperelliptic Curve Cryptography for a + /// description. + /// Note: Doesn't work if `p == q`. pub fn curve_msm( &mut self, p: &AffinePointTarget, @@ -32,6 +35,7 @@ impl, const D: usize> CircuitBuilder { let rando_t = self.constant_affine_point(rando); let neg_rando = self.constant_affine_point(-rando); + // Precomputes `precomputation[i + 4*j] = i*p + j*q` for `i,j=0..4`. let mut precomputation = vec![p.clone(); 16]; let mut cur_p = rando_t.clone(); let mut cur_q = rando_t.clone(); diff --git a/plonky2/src/gadgets/glv.rs b/plonky2/src/gadgets/glv.rs index f0c4704b..4bc3efd6 100644 --- a/plonky2/src/gadgets/glv.rs +++ b/plonky2/src/gadgets/glv.rs @@ -19,7 +19,6 @@ impl, const D: usize> CircuitBuilder { self.constant_nonnative(GLV_BETA) } - // TODO: Add decomposition check. pub fn decompose_secp256k1_scalar( &mut self, k: &NonNativeTarget, From 5febea778be1efb0f53225dff0bca327a561e777 Mon Sep 17 00:00:00 2001 From: wborgeaud Date: Thu, 3 Mar 2022 08:14:56 +0100 Subject: [PATCH 49/56] Fixes --- plonky2/src/gadgets/curve_fixed_base.rs | 4 +- plonky2/src/gadgets/curve_msm.rs | 116 +----------------------- plonky2/src/iop/generator.rs | 62 ------------- 3 files changed, 3 insertions(+), 179 deletions(-) diff --git a/plonky2/src/gadgets/curve_fixed_base.rs b/plonky2/src/gadgets/curve_fixed_base.rs index 3c470044..e248d951 100644 --- a/plonky2/src/gadgets/curve_fixed_base.rs +++ b/plonky2/src/gadgets/curve_fixed_base.rs @@ -11,7 +11,7 @@ use crate::plonk::circuit_builder::CircuitBuilder; use crate::plonk::config::{GenericHashOut, Hasher}; impl, const D: usize> CircuitBuilder { - /// Do windowed fixed-base scalar multiplication, using a 4-bit window. + /// Compute windowed fixed-base scalar multiplication, using a 4-bit window. pub fn fixed_base_curve_mul( &mut self, base: AffinePoint, @@ -38,7 +38,7 @@ impl, const D: usize> CircuitBuilder { let mut result = self.constant_affine_point(rando); // `s * P = sum s_i * P_i` with `P_i = (16^i) * P` and `s = sum s_i * (16^i)`. for (limb, point) in limbs.into_iter().zip(scaled_base) { - // Holds `t * P_i` for `p=0..16`. + // `muls_point[t] = t * P_i` for `t=0..16`. let muls_point = (0..16) .scan(AffinePoint::ZERO, |acc, _| { let tmp = *acc; diff --git a/plonky2/src/gadgets/curve_msm.rs b/plonky2/src/gadgets/curve_msm.rs index 5d505c4d..99aa0f36 100644 --- a/plonky2/src/gadgets/curve_msm.rs +++ b/plonky2/src/gadgets/curve_msm.rs @@ -78,14 +78,10 @@ impl, const D: usize> CircuitBuilder { #[cfg(test)] mod tests { - use std::str::FromStr; - use anyhow::Result; - use num::BigUint; - use plonky2_field::secp256k1_base::Secp256K1Base; use plonky2_field::secp256k1_scalar::Secp256K1Scalar; - use crate::curve::curve_types::{AffinePoint, Curve, CurveScalar}; + use crate::curve::curve_types::{Curve, CurveScalar}; use crate::curve::secp256k1::Secp256K1; use crate::field::field_types::Field; use crate::iop::witness::PartialWitness; @@ -133,114 +129,4 @@ mod tests { verify(proof, &data.verifier_only, &data.common) } - - #[test] - fn test_naive_msm() -> Result<()> { - const D: usize = 2; - type C = PoseidonGoldilocksConfig; - type F = >::F; - - let config = CircuitConfig::standard_ecc_config(); - - let pw = PartialWitness::new(); - let mut builder = CircuitBuilder::::new(config); - - let p = - (CurveScalar(Secp256K1Scalar::rand()) * Secp256K1::GENERATOR_PROJECTIVE).to_affine(); - let q = - (CurveScalar(Secp256K1Scalar::rand()) * Secp256K1::GENERATOR_PROJECTIVE).to_affine(); - let n = Secp256K1Scalar::rand(); - let m = Secp256K1Scalar::rand(); - - let res = - (CurveScalar(n) * p.to_projective() + CurveScalar(m) * q.to_projective()).to_affine(); - let res_expected = builder.constant_affine_point(res); - builder.curve_assert_valid(&res_expected); - - let p_target = builder.constant_affine_point(p); - let q_target = builder.constant_affine_point(q); - let n_target = builder.constant_nonnative(n); - let m_target = builder.constant_nonnative(m); - - let res0_target = builder.curve_scalar_mul_windowed(&p_target, &n_target); - let res1_target = builder.curve_scalar_mul_windowed(&q_target, &m_target); - let res_target = builder.curve_add(&res0_target, &res1_target); - builder.curve_assert_valid(&res_target); - - builder.connect_affine_point(&res_target, &res_expected); - - dbg!(builder.num_gates()); - let data = builder.build::(); - let proof = data.prove(pw).unwrap(); - - verify(proof, &data.verifier_only, &data.common) - } - - #[test] - fn test_curve_lul() -> Result<()> { - const D: usize = 2; - type C = PoseidonGoldilocksConfig; - type F = >::F; - - let config = CircuitConfig::standard_ecc_config(); - - let pw = PartialWitness::new(); - let mut builder = CircuitBuilder::::new(config); - - let p = AffinePoint:: { - x: Secp256K1Base::from_biguint( - BigUint::from_str( - "95702873347299649035220040874584348285675823985309557645567012532974768144045", - ) - .unwrap(), - ), - y: Secp256K1Base::from_biguint( - BigUint::from_str( - "34849299245821426255020320369755722155634282348110887335812955146294938249053", - ) - .unwrap(), - ), - zero: false, - }; - let q = AffinePoint:: { - x: Secp256K1Base::from_biguint( - BigUint::from_str( - "66037057977021147605301350925941983227524093291368248236634649161657340356645", - ) - .unwrap(), - ), - y: Secp256K1Base::from_biguint( - BigUint::from_str( - "80942789991494769168550664638932185697635702317529676703644628861613896422610", - ) - .unwrap(), - ), - zero: false, - }; - - let n = BigUint::from_str("89874493710619023150462632713212469930").unwrap(); - let m = BigUint::from_str("76073901947022186525975758425319149118").unwrap(); - - let res = (CurveScalar(Secp256K1Scalar::from_biguint(n.clone())) * p.to_projective() - + CurveScalar(Secp256K1Scalar::from_biguint(m.clone())) * q.to_projective()) - .to_affine(); - let res_expected = builder.constant_affine_point(res); - builder.curve_assert_valid(&res_expected); - - let p_target = builder.constant_affine_point(p); - let q_target = builder.constant_affine_point(q); - let n_target = builder.constant_biguint(&n); - let m_target = builder.constant_biguint(&m); - - let res_target = builder.curve_msm(&p_target, &q_target, &n_target, &m_target); - builder.curve_assert_valid(&res_target); - - builder.connect_affine_point(&res_target, &res_expected); - - dbg!(builder.num_gates()); - let data = builder.build::(); - let proof = data.prove(pw).unwrap(); - - verify(proof, &data.verifier_only, &data.common) - } } diff --git a/plonky2/src/iop/generator.rs b/plonky2/src/iop/generator.rs index 4dcd11da..1569e889 100644 --- a/plonky2/src/iop/generator.rs +++ b/plonky2/src/iop/generator.rs @@ -89,68 +89,6 @@ pub(crate) fn generate_partial_witness< } pending_generator_indices = next_pending_generator_indices; - // for t in [ - // Target::VirtualTarget { index: 57934 }, - // Target::VirtualTarget { index: 57935 }, - // Target::VirtualTarget { index: 57936 }, - // Target::VirtualTarget { index: 57937 }, - // Target::VirtualTarget { index: 57938 }, - // Target::VirtualTarget { index: 57939 }, - // Target::VirtualTarget { index: 57940 }, - // Target::VirtualTarget { index: 57941 }, - // ] { - // if let Some(v) = witness.try_get_target(t) { - // println!("a {}", v); - // } - // } - // for t in [ - // Target::VirtualTarget { index: 57952 }, - // Target::VirtualTarget { index: 57953 }, - // Target::VirtualTarget { index: 57954 }, - // Target::VirtualTarget { index: 57955 }, - // Target::VirtualTarget { index: 57956 }, - // Target::VirtualTarget { index: 57957 }, - // Target::VirtualTarget { index: 57958 }, - // Target::VirtualTarget { index: 57959 }, - // ] { - // if let Some(v) = witness.try_get_target(t) { - // println!("b {}", v); - // } - // } - // - // let t = Target::Wire(Wire { - // gate: 141_857, - // input: 8, - // }); - // if let Some(v) = witness.try_get_target(t) { - // println!("prod_exp {}", v); - // } - // let t = Target::Wire(Wire { - // gate: 141_863, - // input: 22, - // }); - // if let Some(v) = witness.try_get_target(t) { - // println!("prod act {}", v); - // } - // let t = Target::Wire(Wire { gate: 9, input: 3 }); - // if let Some(v) = witness.try_get_target(t) { - // println!("modulus {}", v); - // } - // let t = Target::VirtualTarget { index: 57_976 }; - // if let Some(v) = witness.try_get_target(t) { - // println!("overflow {}", v); - // } - // let t = Target::Wire(Wire { - // gate: 141_885, - // input: 8, - // }); - // if let Some(v) = witness.try_get_target(t) { - // println!("mod time ov {}", v); - // } - // let t = Target::VirtualTarget { index: 57_968 }; - // if let Some(v) = witness.try_get_target(t) { - // println!("prod {}", v); - // } } assert_eq!( From 3a68a458c4638f575625b9db55377902981b8bb0 Mon Sep 17 00:00:00 2001 From: wborgeaud Date: Thu, 3 Mar 2022 08:44:27 +0100 Subject: [PATCH 50/56] Ignore large tests --- plonky2/src/gadgets/curve_fixed_base.rs | 1 + plonky2/src/gadgets/curve_msm.rs | 1 + 2 files changed, 2 insertions(+) diff --git a/plonky2/src/gadgets/curve_fixed_base.rs b/plonky2/src/gadgets/curve_fixed_base.rs index e248d951..f28e45d1 100644 --- a/plonky2/src/gadgets/curve_fixed_base.rs +++ b/plonky2/src/gadgets/curve_fixed_base.rs @@ -75,6 +75,7 @@ mod tests { use crate::plonk::verifier::verify; #[test] + #[ignore] fn test_fixed_base() -> Result<()> { const D: usize = 2; type C = PoseidonGoldilocksConfig; diff --git a/plonky2/src/gadgets/curve_msm.rs b/plonky2/src/gadgets/curve_msm.rs index 99aa0f36..8d019b3a 100644 --- a/plonky2/src/gadgets/curve_msm.rs +++ b/plonky2/src/gadgets/curve_msm.rs @@ -91,6 +91,7 @@ mod tests { use crate::plonk::verifier::verify; #[test] + #[ignore] fn test_curve_msm() -> Result<()> { const D: usize = 2; type C = PoseidonGoldilocksConfig; From 954eaf16f26dd229f558a438c746ffe949a51147 Mon Sep 17 00:00:00 2001 From: wborgeaud Date: Sat, 5 Mar 2022 02:36:08 +0100 Subject: [PATCH 51/56] PR feedback --- plonky2/src/gadgets/curve_msm.rs | 12 ++++++------ plonky2/src/gadgets/ecdsa.rs | 26 +++++++++++++------------- plonky2/src/gadgets/glv.rs | 2 +- plonky2/src/gadgets/split_nonnative.rs | 8 ++++++-- 4 files changed, 26 insertions(+), 22 deletions(-) diff --git a/plonky2/src/gadgets/curve_msm.rs b/plonky2/src/gadgets/curve_msm.rs index 8d019b3a..fba7c229 100644 --- a/plonky2/src/gadgets/curve_msm.rs +++ b/plonky2/src/gadgets/curve_msm.rs @@ -3,8 +3,8 @@ use plonky2_field::extension_field::Extendable; use crate::curve::curve_types::{Curve, CurveScalar}; use crate::field::field_types::Field; -use crate::gadgets::biguint::BigUintTarget; use crate::gadgets::curve::AffinePointTarget; +use crate::gadgets::nonnative::NonNativeTarget; use crate::hash::hash_types::RichField; use crate::hash::keccak::KeccakHash; use crate::plonk::circuit_builder::CircuitBuilder; @@ -19,11 +19,11 @@ impl, const D: usize> CircuitBuilder { &mut self, p: &AffinePointTarget, q: &AffinePointTarget, - n: &BigUintTarget, - m: &BigUintTarget, + n: &NonNativeTarget, + m: &NonNativeTarget, ) -> AffinePointTarget { - let limbs_n = self.split_biguint_to_2_bit_limbs(n); - let limbs_m = self.split_biguint_to_2_bit_limbs(m); + let limbs_n = self.split_nonnative_to_2_bit_limbs(n); + let limbs_m = self.split_nonnative_to_2_bit_limbs(m); assert_eq!(limbs_n.len(), limbs_m.len()); let num_limbs = limbs_n.len(); @@ -119,7 +119,7 @@ mod tests { let n_target = builder.constant_nonnative(n); let m_target = builder.constant_nonnative(m); - let res_target = builder.curve_msm(&p_target, &q_target, &n_target.value, &m_target.value); + let res_target = builder.curve_msm(&p_target, &q_target, &n_target, &m_target); builder.curve_assert_valid(&res_target); builder.connect_affine_point(&res_target, &res_expected); diff --git a/plonky2/src/gadgets/ecdsa.rs b/plonky2/src/gadgets/ecdsa.rs index 672700c8..a376e56a 100644 --- a/plonky2/src/gadgets/ecdsa.rs +++ b/plonky2/src/gadgets/ecdsa.rs @@ -65,25 +65,13 @@ mod tests { use crate::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; use crate::plonk::verifier::verify; - #[test] - #[ignore] - fn test_ecdsa_circuit() -> Result<()> { + fn test_ecdsa_circuit_with_config(config: CircuitConfig) -> Result<()> { const D: usize = 2; type C = PoseidonGoldilocksConfig; type F = >::F; type Curve = Secp256K1; - const WIDE: bool = true; - - let config = if WIDE { - // < 2^16 gates. - CircuitConfig::wide_ecc_config() - } else { - // < 2^17 gates. - CircuitConfig::standard_ecc_config() - }; - let pw = PartialWitness::new(); let mut builder = CircuitBuilder::::new(config); @@ -112,4 +100,16 @@ mod tests { let proof = data.prove(pw).unwrap(); verify(proof, &data.verifier_only, &data.common) } + + #[test] + #[ignore] + fn test_ecdsa_circuit_narrow() -> Result<()> { + test_ecdsa_circuit_with_config(CircuitConfig::standard_ecc_config()) + } + + #[test] + #[ignore] + fn test_ecdsa_circuit_wide() -> Result<()> { + test_ecdsa_circuit_with_config(CircuitConfig::wide_ecc_config()) + } } diff --git a/plonky2/src/gadgets/glv.rs b/plonky2/src/gadgets/glv.rs index 4bc3efd6..8a0179ec 100644 --- a/plonky2/src/gadgets/glv.rs +++ b/plonky2/src/gadgets/glv.rs @@ -69,7 +69,7 @@ impl, const D: usize> CircuitBuilder { let p_neg = self.curve_conditional_neg(p, k1_neg); let sp_neg = self.curve_conditional_neg(&sp, k2_neg); - self.curve_msm(&p_neg, &sp_neg, &k1.value, &k2.value) + self.curve_msm(&p_neg, &sp_neg, &k1, &k2) } } diff --git a/plonky2/src/gadgets/split_nonnative.rs b/plonky2/src/gadgets/split_nonnative.rs index becf1177..18fc0264 100644 --- a/plonky2/src/gadgets/split_nonnative.rs +++ b/plonky2/src/gadgets/split_nonnative.rs @@ -35,8 +35,12 @@ impl, const D: usize> CircuitBuilder { .collect() } - pub fn split_biguint_to_2_bit_limbs(&mut self, val: &BigUintTarget) -> Vec { - val.limbs + pub fn split_nonnative_to_2_bit_limbs( + &mut self, + val: &NonNativeTarget, + ) -> Vec { + val.value + .limbs .iter() .flat_map(|&l| self.split_le_base::<4>(l.0, 16)) .collect() From 2e5c2e89843aced81b6c711a0cdbd93b5a1318f6 Mon Sep 17 00:00:00 2001 From: wborgeaud Date: Tue, 15 Mar 2022 16:55:08 +0100 Subject: [PATCH 52/56] Add ecdsa module --- Cargo.toml | 2 +- ecdsa/src/gadgets/mod.rs | 9 + plonky2/src/curve/curve_adds.rs | 158 ----- plonky2/src/curve/curve_msm.rs | 265 -------- plonky2/src/curve/curve_multiplication.rs | 99 --- plonky2/src/curve/curve_summation.rs | 239 ------- plonky2/src/curve/curve_types.rs | 282 --------- plonky2/src/curve/ecdsa.rs | 78 --- plonky2/src/curve/glv.rs | 136 ---- plonky2/src/curve/mod.rs | 8 - plonky2/src/curve/secp256k1.rs | 101 --- plonky2/src/gadgets/biguint.rs | 418 ------------ plonky2/src/gadgets/curve.rs | 434 ------------- plonky2/src/gadgets/curve_fixed_base.rs | 110 ---- plonky2/src/gadgets/curve_msm.rs | 133 ---- plonky2/src/gadgets/curve_windowed_mul.rs | 224 ------- plonky2/src/gadgets/ecdsa.rs | 115 ---- plonky2/src/gadgets/glv.rs | 148 ----- plonky2/src/gadgets/mod.rs | 9 - plonky2/src/gadgets/nonnative.rs | 732 ---------------------- plonky2/src/gadgets/split_nonnative.rs | 109 ---- plonky2/src/iop/generator.rs | 20 +- plonky2/src/iop/witness.rs | 40 +- plonky2/src/lib.rs | 1 - 24 files changed, 12 insertions(+), 3858 deletions(-) create mode 100644 ecdsa/src/gadgets/mod.rs delete mode 100644 plonky2/src/curve/curve_adds.rs delete mode 100644 plonky2/src/curve/curve_msm.rs delete mode 100644 plonky2/src/curve/curve_multiplication.rs delete mode 100644 plonky2/src/curve/curve_summation.rs delete mode 100644 plonky2/src/curve/curve_types.rs delete mode 100644 plonky2/src/curve/ecdsa.rs delete mode 100644 plonky2/src/curve/glv.rs delete mode 100644 plonky2/src/curve/mod.rs delete mode 100644 plonky2/src/curve/secp256k1.rs delete mode 100644 plonky2/src/gadgets/biguint.rs delete mode 100644 plonky2/src/gadgets/curve.rs delete mode 100644 plonky2/src/gadgets/curve_fixed_base.rs delete mode 100644 plonky2/src/gadgets/curve_msm.rs delete mode 100644 plonky2/src/gadgets/curve_windowed_mul.rs delete mode 100644 plonky2/src/gadgets/ecdsa.rs delete mode 100644 plonky2/src/gadgets/glv.rs delete mode 100644 plonky2/src/gadgets/nonnative.rs delete mode 100644 plonky2/src/gadgets/split_nonnative.rs diff --git a/Cargo.toml b/Cargo.toml index cc070d96..00a4b28f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,5 +1,5 @@ [workspace] -members = ["field", "insertion", "plonky2", "starky", "system_zero", "util", "waksman"] +members = ["field", "insertion", "plonky2", "starky", "system_zero", "util", "waksman", "ecdsa"] [profile.release] opt-level = 3 diff --git a/ecdsa/src/gadgets/mod.rs b/ecdsa/src/gadgets/mod.rs new file mode 100644 index 00000000..35b10100 --- /dev/null +++ b/ecdsa/src/gadgets/mod.rs @@ -0,0 +1,9 @@ +pub mod biguint; +pub mod curve; +pub mod curve_fixed_base; +pub mod curve_msm; +pub mod curve_windowed_mul; +pub mod ecdsa; +pub mod glv; +pub mod nonnative; +pub mod split_nonnative; diff --git a/plonky2/src/curve/curve_adds.rs b/plonky2/src/curve/curve_adds.rs deleted file mode 100644 index 98dbc697..00000000 --- a/plonky2/src/curve/curve_adds.rs +++ /dev/null @@ -1,158 +0,0 @@ -use std::ops::Add; - -use plonky2_field::field_types::Field; -use plonky2_field::ops::Square; - -use crate::curve::curve_types::{AffinePoint, Curve, ProjectivePoint}; - -impl Add> for ProjectivePoint { - type Output = ProjectivePoint; - - fn add(self, rhs: ProjectivePoint) -> Self::Output { - let ProjectivePoint { - x: x1, - y: y1, - z: z1, - } = self; - let ProjectivePoint { - x: x2, - y: y2, - z: z2, - } = rhs; - - if z1 == C::BaseField::ZERO { - return rhs; - } - if z2 == C::BaseField::ZERO { - return self; - } - - let x1z2 = x1 * z2; - let y1z2 = y1 * z2; - let x2z1 = x2 * z1; - let y2z1 = y2 * z1; - - // Check if we're doubling or adding inverses. - if x1z2 == x2z1 { - if y1z2 == y2z1 { - // TODO: inline to avoid redundant muls. - return self.double(); - } - if y1z2 == -y2z1 { - return ProjectivePoint::ZERO; - } - } - - // From https://www.hyperelliptic.org/EFD/g1p/data/shortw/projective/addition/add-1998-cmo-2 - let z1z2 = z1 * z2; - let u = y2z1 - y1z2; - let uu = u.square(); - let v = x2z1 - x1z2; - let vv = v.square(); - let vvv = v * vv; - let r = vv * x1z2; - let a = uu * z1z2 - vvv - r.double(); - let x3 = v * a; - let y3 = u * (r - a) - vvv * y1z2; - let z3 = vvv * z1z2; - ProjectivePoint::nonzero(x3, y3, z3) - } -} - -impl Add> for ProjectivePoint { - type Output = ProjectivePoint; - - fn add(self, rhs: AffinePoint) -> Self::Output { - let ProjectivePoint { - x: x1, - y: y1, - z: z1, - } = self; - let AffinePoint { - x: x2, - y: y2, - zero: zero2, - } = rhs; - - if z1 == C::BaseField::ZERO { - return rhs.to_projective(); - } - if zero2 { - return self; - } - - let x2z1 = x2 * z1; - let y2z1 = y2 * z1; - - // Check if we're doubling or adding inverses. - if x1 == x2z1 { - if y1 == y2z1 { - // TODO: inline to avoid redundant muls. - return self.double(); - } - if y1 == -y2z1 { - return ProjectivePoint::ZERO; - } - } - - // From https://www.hyperelliptic.org/EFD/g1p/data/shortw/projective/addition/madd-1998-cmo - let u = y2z1 - y1; - let uu = u.square(); - let v = x2z1 - x1; - let vv = v.square(); - let vvv = v * vv; - let r = vv * x1; - let a = uu * z1 - vvv - r.double(); - let x3 = v * a; - let y3 = u * (r - a) - vvv * y1; - let z3 = vvv * z1; - ProjectivePoint::nonzero(x3, y3, z3) - } -} - -impl Add> for AffinePoint { - type Output = ProjectivePoint; - - fn add(self, rhs: AffinePoint) -> Self::Output { - let AffinePoint { - x: x1, - y: y1, - zero: zero1, - } = self; - let AffinePoint { - x: x2, - y: y2, - zero: zero2, - } = rhs; - - if zero1 { - return rhs.to_projective(); - } - if zero2 { - return self.to_projective(); - } - - // Check if we're doubling or adding inverses. - if x1 == x2 { - if y1 == y2 { - return self.to_projective().double(); - } - if y1 == -y2 { - return ProjectivePoint::ZERO; - } - } - - // From https://www.hyperelliptic.org/EFD/g1p/data/shortw/projective/addition/mmadd-1998-cmo - let u = y2 - y1; - let uu = u.square(); - let v = x2 - x1; - let vv = v.square(); - let vvv = v * vv; - let r = vv * x1; - let a = uu - vvv - r.double(); - let x3 = v * a; - let y3 = u * (r - a) - vvv * y1; - let z3 = vvv; - ProjectivePoint::nonzero(x3, y3, z3) - } -} diff --git a/plonky2/src/curve/curve_msm.rs b/plonky2/src/curve/curve_msm.rs deleted file mode 100644 index 4c274c1c..00000000 --- a/plonky2/src/curve/curve_msm.rs +++ /dev/null @@ -1,265 +0,0 @@ -use itertools::Itertools; -use plonky2_field::field_types::Field; -use plonky2_field::field_types::PrimeField; -use rayon::prelude::*; - -use crate::curve::curve_summation::affine_multisummation_best; -use crate::curve::curve_types::{AffinePoint, Curve, ProjectivePoint}; - -/// In Yao's method, we compute an affine summation for each digit. In a parallel setting, it would -/// be easiest to assign individual summations to threads, but this would be sub-optimal because -/// multi-summations can be more efficient than repeating individual summations (see -/// `affine_multisummation_best`). Thus we divide digits into large chunks, and assign chunks of -/// digits to threads. Note that there is a delicate balance here, as large chunks can result in -/// uneven distributions of work among threads. -const DIGITS_PER_CHUNK: usize = 80; - -#[derive(Clone, Debug)] -pub struct MsmPrecomputation { - /// For each generator (in the order they were passed to `msm_precompute`), contains a vector - /// of powers, i.e. [(2^w)^i] for i < DIGITS. - // TODO: Use compressed coordinates here. - powers_per_generator: Vec>>, - - /// The window size. - w: usize, -} - -pub fn msm_precompute( - generators: &[ProjectivePoint], - w: usize, -) -> MsmPrecomputation { - MsmPrecomputation { - powers_per_generator: generators - .into_par_iter() - .map(|&g| precompute_single_generator(g, w)) - .collect(), - w, - } -} - -fn precompute_single_generator(g: ProjectivePoint, w: usize) -> Vec> { - let digits = (C::ScalarField::BITS + w - 1) / w; - let mut powers: Vec> = Vec::with_capacity(digits); - powers.push(g); - for i in 1..digits { - let mut power_i_proj = powers[i - 1]; - for _j in 0..w { - power_i_proj = power_i_proj.double(); - } - powers.push(power_i_proj); - } - ProjectivePoint::batch_to_affine(&powers) -} - -pub fn msm_parallel( - scalars: &[C::ScalarField], - generators: &[ProjectivePoint], - w: usize, -) -> ProjectivePoint { - let precomputation = msm_precompute(generators, w); - msm_execute_parallel(&precomputation, scalars) -} - -pub fn msm_execute( - precomputation: &MsmPrecomputation, - scalars: &[C::ScalarField], -) -> ProjectivePoint { - assert_eq!(precomputation.powers_per_generator.len(), scalars.len()); - let w = precomputation.w; - let digits = (C::ScalarField::BITS + w - 1) / w; - let base = 1 << w; - - // This is a variant of Yao's method, adapted to the multi-scalar setting. Because we use - // extremely large windows, the repeated scans in Yao's method could be more expensive than the - // actual group operations. To avoid this, we store a multimap from each possible digit to the - // positions in which that digit occurs in the scalars. These positions have the form (i, j), - // where i is the index of the generator and j is an index into the digits of the scalar - // associated with that generator. - let mut digit_occurrences: Vec> = Vec::with_capacity(digits); - for _i in 0..base { - digit_occurrences.push(Vec::new()); - } - for (i, scalar) in scalars.iter().enumerate() { - let digits = to_digits::(scalar, w); - for (j, &digit) in digits.iter().enumerate() { - digit_occurrences[digit].push((i, j)); - } - } - - let mut y = ProjectivePoint::ZERO; - let mut u = ProjectivePoint::ZERO; - - for digit in (1..base).rev() { - for &(i, j) in &digit_occurrences[digit] { - u = u + precomputation.powers_per_generator[i][j]; - } - y = y + u; - } - - y -} - -pub fn msm_execute_parallel( - precomputation: &MsmPrecomputation, - scalars: &[C::ScalarField], -) -> ProjectivePoint { - assert_eq!(precomputation.powers_per_generator.len(), scalars.len()); - let w = precomputation.w; - let digits = (C::ScalarField::BITS + w - 1) / w; - let base = 1 << w; - - // This is a variant of Yao's method, adapted to the multi-scalar setting. Because we use - // extremely large windows, the repeated scans in Yao's method could be more expensive than the - // actual group operations. To avoid this, we store a multimap from each possible digit to the - // positions in which that digit occurs in the scalars. These positions have the form (i, j), - // where i is the index of the generator and j is an index into the digits of the scalar - // associated with that generator. - let mut digit_occurrences: Vec> = Vec::with_capacity(digits); - for _i in 0..base { - digit_occurrences.push(Vec::new()); - } - for (i, scalar) in scalars.iter().enumerate() { - let digits = to_digits::(scalar, w); - for (j, &digit) in digits.iter().enumerate() { - digit_occurrences[digit].push((i, j)); - } - } - - // For each digit, we add up the powers associated with all occurrences that digit. - let digits: Vec = (0..base).collect(); - let digit_acc: Vec> = digits - .par_chunks(DIGITS_PER_CHUNK) - .flat_map(|chunk| { - let summations: Vec>> = chunk - .iter() - .map(|&digit| { - digit_occurrences[digit] - .iter() - .map(|&(i, j)| precomputation.powers_per_generator[i][j]) - .collect() - }) - .collect(); - affine_multisummation_best(summations) - }) - .collect(); - // println!("Computing the per-digit summations (in parallel) took {}s", start.elapsed().as_secs_f64()); - - let mut y = ProjectivePoint::ZERO; - let mut u = ProjectivePoint::ZERO; - for digit in (1..base).rev() { - u = u + digit_acc[digit]; - y = y + u; - } - // println!("Final summation (sequential) {}s", start.elapsed().as_secs_f64()); - y -} - -pub(crate) fn to_digits(x: &C::ScalarField, w: usize) -> Vec { - let scalar_bits = C::ScalarField::BITS; - let num_digits = (scalar_bits + w - 1) / w; - - // Convert x to a bool array. - let x_canonical: Vec<_> = x - .to_canonical_biguint() - .to_u64_digits() - .iter() - .cloned() - .pad_using(scalar_bits / 64, |_| 0) - .collect(); - let mut x_bits = Vec::with_capacity(scalar_bits); - for i in 0..scalar_bits { - x_bits.push((x_canonical[i / 64] >> (i as u64 % 64) & 1) != 0); - } - - let mut digits = Vec::with_capacity(num_digits); - for i in 0..num_digits { - let mut digit = 0; - for j in ((i * w)..((i + 1) * w).min(scalar_bits)).rev() { - digit <<= 1; - digit |= x_bits[j] as usize; - } - digits.push(digit); - } - digits -} - -#[cfg(test)] -mod tests { - use num::BigUint; - use plonky2_field::field_types::Field; - use plonky2_field::field_types::PrimeField; - use plonky2_field::secp256k1_scalar::Secp256K1Scalar; - - use crate::curve::curve_msm::{msm_execute, msm_precompute, to_digits}; - use crate::curve::curve_types::Curve; - use crate::curve::secp256k1::Secp256K1; - - #[test] - fn test_to_digits() { - let x_canonical = [ - 0b10101010101010101010101010101010, - 0b10101010101010101010101010101010, - 0b11001100110011001100110011001100, - 0b11001100110011001100110011001100, - 0b11110000111100001111000011110000, - 0b11110000111100001111000011110000, - 0b00001111111111111111111111111111, - 0b11111111111111111111111111111111, - ]; - let x = Secp256K1Scalar::from_biguint(BigUint::from_slice(&x_canonical)); - assert_eq!(x.to_canonical_biguint().to_u32_digits(), x_canonical); - assert_eq!( - to_digits::(&x, 17), - vec![ - 0b01010101010101010, - 0b10101010101010101, - 0b01010101010101010, - 0b11001010101010101, - 0b01100110011001100, - 0b00110011001100110, - 0b10011001100110011, - 0b11110000110011001, - 0b01111000011110000, - 0b00111100001111000, - 0b00011110000111100, - 0b11111111111111110, - 0b01111111111111111, - 0b11111111111111000, - 0b11111111111111111, - 0b1, - ] - ); - } - - #[test] - fn test_msm() { - let w = 5; - - let generator_1 = Secp256K1::GENERATOR_PROJECTIVE; - let generator_2 = generator_1 + generator_1; - let generator_3 = generator_1 + generator_2; - - let scalar_1 = Secp256K1Scalar::from_biguint(BigUint::from_slice(&[ - 11111111, 22222222, 33333333, 44444444, - ])); - let scalar_2 = Secp256K1Scalar::from_biguint(BigUint::from_slice(&[ - 22222222, 22222222, 33333333, 44444444, - ])); - let scalar_3 = Secp256K1Scalar::from_biguint(BigUint::from_slice(&[ - 33333333, 22222222, 33333333, 44444444, - ])); - - let generators = vec![generator_1, generator_2, generator_3]; - let scalars = vec![scalar_1, scalar_2, scalar_3]; - - let precomputation = msm_precompute(&generators, w); - let result_msm = msm_execute(&precomputation, &scalars); - - let result_naive = Secp256K1::convert(scalar_1) * generator_1 - + Secp256K1::convert(scalar_2) * generator_2 - + Secp256K1::convert(scalar_3) * generator_3; - - assert_eq!(result_msm, result_naive); - } -} diff --git a/plonky2/src/curve/curve_multiplication.rs b/plonky2/src/curve/curve_multiplication.rs deleted file mode 100644 index c6fbbd83..00000000 --- a/plonky2/src/curve/curve_multiplication.rs +++ /dev/null @@ -1,99 +0,0 @@ -use std::ops::Mul; - -use plonky2_field::field_types::Field; -use plonky2_field::field_types::PrimeField; - -use crate::curve::curve_types::{Curve, CurveScalar, ProjectivePoint}; - -const WINDOW_BITS: usize = 4; -const BASE: usize = 1 << WINDOW_BITS; - -fn digits_per_scalar() -> usize { - (C::ScalarField::BITS + WINDOW_BITS - 1) / WINDOW_BITS -} - -/// Precomputed state used for scalar x ProjectivePoint multiplications, -/// specific to a particular generator. -#[derive(Clone)] -pub struct MultiplicationPrecomputation { - /// [(2^w)^i] g for each i < digits_per_scalar. - powers: Vec>, -} - -impl ProjectivePoint { - pub fn mul_precompute(&self) -> MultiplicationPrecomputation { - let num_digits = digits_per_scalar::(); - let mut powers = Vec::with_capacity(num_digits); - powers.push(*self); - for i in 1..num_digits { - let mut power_i = powers[i - 1]; - for _j in 0..WINDOW_BITS { - power_i = power_i.double(); - } - powers.push(power_i); - } - - MultiplicationPrecomputation { powers } - } - - pub fn mul_with_precomputation( - &self, - scalar: C::ScalarField, - precomputation: MultiplicationPrecomputation, - ) -> Self { - // Yao's method; see https://koclab.cs.ucsb.edu/teaching/ecc/eccPapers/Doche-ch09.pdf - let precomputed_powers = precomputation.powers; - - let digits = to_digits::(&scalar); - - let mut y = ProjectivePoint::ZERO; - let mut u = ProjectivePoint::ZERO; - let mut all_summands = Vec::new(); - for j in (1..BASE).rev() { - let mut u_summands = Vec::new(); - for (i, &digit) in digits.iter().enumerate() { - if digit == j as u64 { - u_summands.push(precomputed_powers[i]); - } - } - all_summands.push(u_summands); - } - - let all_sums: Vec> = all_summands - .iter() - .cloned() - .map(|vec| vec.iter().fold(ProjectivePoint::ZERO, |a, &b| a + b)) - .collect(); - for i in 0..all_sums.len() { - u = u + all_sums[i]; - y = y + u; - } - y - } -} - -impl Mul> for CurveScalar { - type Output = ProjectivePoint; - - fn mul(self, rhs: ProjectivePoint) -> Self::Output { - let precomputation = rhs.mul_precompute(); - rhs.mul_with_precomputation(self.0, precomputation) - } -} - -#[allow(clippy::assertions_on_constants)] -fn to_digits(x: &C::ScalarField) -> Vec { - debug_assert!( - 64 % WINDOW_BITS == 0, - "For simplicity, only power-of-two window sizes are handled for now" - ); - let digits_per_u64 = 64 / WINDOW_BITS; - let mut digits = Vec::with_capacity(digits_per_scalar::()); - for limb in x.to_canonical_biguint().to_u64_digits() { - for j in 0..digits_per_u64 { - digits.push((limb >> (j * WINDOW_BITS) as u64) % BASE as u64); - } - } - - digits -} diff --git a/plonky2/src/curve/curve_summation.rs b/plonky2/src/curve/curve_summation.rs deleted file mode 100644 index 7ea01524..00000000 --- a/plonky2/src/curve/curve_summation.rs +++ /dev/null @@ -1,239 +0,0 @@ -use std::iter::Sum; - -use plonky2_field::field_types::Field; -use plonky2_field::ops::Square; - -use crate::curve::curve_types::{AffinePoint, Curve, ProjectivePoint}; - -impl Sum> for ProjectivePoint { - fn sum>>(iter: I) -> ProjectivePoint { - let points: Vec<_> = iter.collect(); - affine_summation_best(points) - } -} - -impl Sum for ProjectivePoint { - fn sum>>(iter: I) -> ProjectivePoint { - iter.fold(ProjectivePoint::ZERO, |acc, x| acc + x) - } -} - -pub fn affine_summation_best(summation: Vec>) -> ProjectivePoint { - let result = affine_multisummation_best(vec![summation]); - debug_assert_eq!(result.len(), 1); - result[0] -} - -pub fn affine_multisummation_best( - summations: Vec>>, -) -> Vec> { - let pairwise_sums: usize = summations.iter().map(|summation| summation.len() / 2).sum(); - - // This threshold is chosen based on data from the summation benchmarks. - if pairwise_sums < 70 { - affine_multisummation_pairwise(summations) - } else { - affine_multisummation_batch_inversion(summations) - } -} - -/// Adds each pair of points using an affine + affine = projective formula, then adds up the -/// intermediate sums using a projective formula. -pub fn affine_multisummation_pairwise( - summations: Vec>>, -) -> Vec> { - summations - .into_iter() - .map(affine_summation_pairwise) - .collect() -} - -/// Adds each pair of points using an affine + affine = projective formula, then adds up the -/// intermediate sums using a projective formula. -pub fn affine_summation_pairwise(points: Vec>) -> ProjectivePoint { - let mut reduced_points: Vec> = Vec::new(); - for chunk in points.chunks(2) { - match chunk.len() { - 1 => reduced_points.push(chunk[0].to_projective()), - 2 => reduced_points.push(chunk[0] + chunk[1]), - _ => panic!(), - } - } - // TODO: Avoid copying (deref) - reduced_points - .iter() - .fold(ProjectivePoint::ZERO, |sum, x| sum + *x) -} - -/// Computes several summations of affine points by applying an affine group law, except that the -/// divisions are batched via Montgomery's trick. -pub fn affine_summation_batch_inversion( - summation: Vec>, -) -> ProjectivePoint { - let result = affine_multisummation_batch_inversion(vec![summation]); - debug_assert_eq!(result.len(), 1); - result[0] -} - -/// Computes several summations of affine points by applying an affine group law, except that the -/// divisions are batched via Montgomery's trick. -pub fn affine_multisummation_batch_inversion( - summations: Vec>>, -) -> Vec> { - let mut elements_to_invert = Vec::new(); - - // For each pair of points, (x1, y1) and (x2, y2), that we're going to add later, we want to - // invert either y (if the points are equal) or x1 - x2 (otherwise). We will use these later. - for summation in &summations { - let n = summation.len(); - // The special case for n=0 is to avoid underflow. - let range_end = if n == 0 { 0 } else { n - 1 }; - - for i in (0..range_end).step_by(2) { - let p1 = summation[i]; - let p2 = summation[i + 1]; - let AffinePoint { - x: x1, - y: y1, - zero: zero1, - } = p1; - let AffinePoint { - x: x2, - y: _y2, - zero: zero2, - } = p2; - - if zero1 || zero2 || p1 == -p2 { - // These are trivial cases where we won't need any inverse. - } else if p1 == p2 { - elements_to_invert.push(y1.double()); - } else { - elements_to_invert.push(x1 - x2); - } - } - } - - let inverses: Vec = - C::BaseField::batch_multiplicative_inverse(&elements_to_invert); - - let mut all_reduced_points = Vec::with_capacity(summations.len()); - let mut inverse_index = 0; - for summation in summations { - let n = summation.len(); - let mut reduced_points = Vec::with_capacity((n + 1) / 2); - - // The special case for n=0 is to avoid underflow. - let range_end = if n == 0 { 0 } else { n - 1 }; - - for i in (0..range_end).step_by(2) { - let p1 = summation[i]; - let p2 = summation[i + 1]; - let AffinePoint { - x: x1, - y: y1, - zero: zero1, - } = p1; - let AffinePoint { - x: x2, - y: y2, - zero: zero2, - } = p2; - - let sum = if zero1 { - p2 - } else if zero2 { - p1 - } else if p1 == -p2 { - AffinePoint::ZERO - } else { - // It's a non-trivial case where we need one of the inverses we computed earlier. - let inverse = inverses[inverse_index]; - inverse_index += 1; - - if p1 == p2 { - // This is the doubling case. - let mut numerator = x1.square().triple(); - if C::A.is_nonzero() { - numerator += C::A; - } - let quotient = numerator * inverse; - let x3 = quotient.square() - x1.double(); - let y3 = quotient * (x1 - x3) - y1; - AffinePoint::nonzero(x3, y3) - } else { - // This is the general case. We use the incomplete addition formulas 4.3 and 4.4. - let quotient = (y1 - y2) * inverse; - let x3 = quotient.square() - x1 - x2; - let y3 = quotient * (x1 - x3) - y1; - AffinePoint::nonzero(x3, y3) - } - }; - reduced_points.push(sum); - } - - // If n is odd, the last point was not part of a pair. - if n % 2 == 1 { - reduced_points.push(summation[n - 1]); - } - - all_reduced_points.push(reduced_points); - } - - // We should have consumed all of the inverses from the batch computation. - debug_assert_eq!(inverse_index, inverses.len()); - - // Recurse with our smaller set of points. - affine_multisummation_best(all_reduced_points) -} - -#[cfg(test)] -mod tests { - use crate::curve::curve_summation::{ - affine_summation_batch_inversion, affine_summation_pairwise, - }; - use crate::curve::curve_types::{Curve, ProjectivePoint}; - use crate::curve::secp256k1::Secp256K1; - - #[test] - fn test_pairwise_affine_summation() { - let g_affine = Secp256K1::GENERATOR_AFFINE; - let g2_affine = (g_affine + g_affine).to_affine(); - let g3_affine = (g_affine + g_affine + g_affine).to_affine(); - let g2_proj = g2_affine.to_projective(); - let g3_proj = g3_affine.to_projective(); - assert_eq!( - affine_summation_pairwise::(vec![g_affine, g_affine]), - g2_proj - ); - assert_eq!( - affine_summation_pairwise::(vec![g_affine, g2_affine]), - g3_proj - ); - assert_eq!( - affine_summation_pairwise::(vec![g_affine, g_affine, g_affine]), - g3_proj - ); - assert_eq!( - affine_summation_pairwise::(vec![]), - ProjectivePoint::ZERO - ); - } - - #[test] - fn test_pairwise_affine_summation_batch_inversion() { - let g = Secp256K1::GENERATOR_AFFINE; - let g_proj = g.to_projective(); - assert_eq!( - affine_summation_batch_inversion::(vec![g, g]), - g_proj + g_proj - ); - assert_eq!( - affine_summation_batch_inversion::(vec![g, g, g]), - g_proj + g_proj + g_proj - ); - assert_eq!( - affine_summation_batch_inversion::(vec![]), - ProjectivePoint::ZERO - ); - } -} diff --git a/plonky2/src/curve/curve_types.rs b/plonky2/src/curve/curve_types.rs deleted file mode 100644 index 264120c7..00000000 --- a/plonky2/src/curve/curve_types.rs +++ /dev/null @@ -1,282 +0,0 @@ -use std::fmt::Debug; -use std::hash::Hash; -use std::ops::Neg; - -use plonky2_field::field_types::{Field, PrimeField}; -use plonky2_field::ops::Square; -use serde::{Deserialize, Serialize}; - -// To avoid implementation conflicts from associated types, -// see https://github.com/rust-lang/rust/issues/20400 -pub struct CurveScalar(pub ::ScalarField); - -/// A short Weierstrass curve. -pub trait Curve: 'static + Sync + Sized + Copy + Debug { - type BaseField: PrimeField; - type ScalarField: PrimeField; - - const A: Self::BaseField; - const B: Self::BaseField; - - const GENERATOR_AFFINE: AffinePoint; - - const GENERATOR_PROJECTIVE: ProjectivePoint = ProjectivePoint { - x: Self::GENERATOR_AFFINE.x, - y: Self::GENERATOR_AFFINE.y, - z: Self::BaseField::ONE, - }; - - fn convert(x: Self::ScalarField) -> CurveScalar { - CurveScalar(x) - } - - fn is_safe_curve() -> bool { - // Added additional check to prevent using vulnerabilties in case a discriminant is equal to 0. - (Self::A.cube().double().double() + Self::B.square().triple().triple().triple()) - .is_nonzero() - } -} - -/// A point on a short Weierstrass curve, represented in affine coordinates. -#[derive(Copy, Clone, Debug, Deserialize, Serialize)] -pub struct AffinePoint { - pub x: C::BaseField, - pub y: C::BaseField, - pub zero: bool, -} - -impl AffinePoint { - pub const ZERO: Self = Self { - x: C::BaseField::ZERO, - y: C::BaseField::ZERO, - zero: true, - }; - - pub fn nonzero(x: C::BaseField, y: C::BaseField) -> Self { - let point = Self { x, y, zero: false }; - debug_assert!(point.is_valid()); - point - } - - pub fn is_valid(&self) -> bool { - let Self { x, y, zero } = *self; - zero || y.square() == x.cube() + C::A * x + C::B - } - - pub fn to_projective(&self) -> ProjectivePoint { - let Self { x, y, zero } = *self; - let z = if zero { - C::BaseField::ZERO - } else { - C::BaseField::ONE - }; - - ProjectivePoint { x, y, z } - } - - pub fn batch_to_projective(affine_points: &[Self]) -> Vec> { - affine_points.iter().map(Self::to_projective).collect() - } - - pub fn double(&self) -> Self { - let AffinePoint { x: x1, y: y1, zero } = *self; - - if zero { - return AffinePoint::ZERO; - } - - let double_y = y1.double(); - let inv_double_y = double_y.inverse(); // (2y)^(-1) - let triple_xx = x1.square().triple(); // 3x^2 - let lambda = (triple_xx + C::A) * inv_double_y; - let x3 = lambda.square() - self.x.double(); - let y3 = lambda * (x1 - x3) - y1; - - Self { - x: x3, - y: y3, - zero: false, - } - } -} - -impl PartialEq for AffinePoint { - fn eq(&self, other: &Self) -> bool { - let AffinePoint { - x: x1, - y: y1, - zero: zero1, - } = *self; - let AffinePoint { - x: x2, - y: y2, - zero: zero2, - } = *other; - if zero1 || zero2 { - return zero1 == zero2; - } - x1 == x2 && y1 == y2 - } -} - -impl Eq for AffinePoint {} - -impl Hash for AffinePoint { - fn hash(&self, state: &mut H) { - if self.zero { - self.zero.hash(state); - } else { - self.x.hash(state); - self.y.hash(state); - } - } -} - -/// A point on a short Weierstrass curve, represented in projective coordinates. -#[derive(Copy, Clone, Debug)] -pub struct ProjectivePoint { - pub x: C::BaseField, - pub y: C::BaseField, - pub z: C::BaseField, -} - -impl ProjectivePoint { - pub const ZERO: Self = Self { - x: C::BaseField::ZERO, - y: C::BaseField::ONE, - z: C::BaseField::ZERO, - }; - - pub fn nonzero(x: C::BaseField, y: C::BaseField, z: C::BaseField) -> Self { - let point = Self { x, y, z }; - debug_assert!(point.is_valid()); - point - } - - pub fn is_valid(&self) -> bool { - let Self { x, y, z } = *self; - z.is_zero() || y.square() * z == x.cube() + C::A * x * z.square() + C::B * z.cube() - } - - pub fn to_affine(&self) -> AffinePoint { - let Self { x, y, z } = *self; - if z == C::BaseField::ZERO { - AffinePoint::ZERO - } else { - let z_inv = z.inverse(); - AffinePoint::nonzero(x * z_inv, y * z_inv) - } - } - - pub fn batch_to_affine(proj_points: &[Self]) -> Vec> { - let n = proj_points.len(); - let zs: Vec = proj_points.iter().map(|pp| pp.z).collect(); - let z_invs = C::BaseField::batch_multiplicative_inverse(&zs); - - let mut result = Vec::with_capacity(n); - for i in 0..n { - let Self { x, y, z } = proj_points[i]; - result.push(if z == C::BaseField::ZERO { - AffinePoint::ZERO - } else { - let z_inv = z_invs[i]; - AffinePoint::nonzero(x * z_inv, y * z_inv) - }); - } - result - } - - // From https://www.hyperelliptic.org/EFD/g1p/data/shortw/projective/doubling/dbl-2007-bl - pub fn double(&self) -> Self { - let Self { x, y, z } = *self; - if z == C::BaseField::ZERO { - return ProjectivePoint::ZERO; - } - - let xx = x.square(); - let zz = z.square(); - let mut w = xx.triple(); - if C::A.is_nonzero() { - w += C::A * zz; - } - let s = y.double() * z; - let r = y * s; - let rr = r.square(); - let b = (x + r).square() - (xx + rr); - let h = w.square() - b.double(); - let x3 = h * s; - let y3 = w * (b - h) - rr.double(); - let z3 = s.cube(); - Self { - x: x3, - y: y3, - z: z3, - } - } - - pub fn add_slices(a: &[Self], b: &[Self]) -> Vec { - assert_eq!(a.len(), b.len()); - a.iter() - .zip(b.iter()) - .map(|(&a_i, &b_i)| a_i + b_i) - .collect() - } - - pub fn neg(&self) -> Self { - Self { - x: self.x, - y: -self.y, - z: self.z, - } - } -} - -impl PartialEq for ProjectivePoint { - fn eq(&self, other: &Self) -> bool { - let ProjectivePoint { - x: x1, - y: y1, - z: z1, - } = *self; - let ProjectivePoint { - x: x2, - y: y2, - z: z2, - } = *other; - if z1 == C::BaseField::ZERO || z2 == C::BaseField::ZERO { - return z1 == z2; - } - - // We want to compare (x1/z1, y1/z1) == (x2/z2, y2/z2). - // But to avoid field division, it is better to compare (x1*z2, y1*z2) == (x2*z1, y2*z1). - x1 * z2 == x2 * z1 && y1 * z2 == y2 * z1 - } -} - -impl Eq for ProjectivePoint {} - -impl Neg for AffinePoint { - type Output = AffinePoint; - - fn neg(self) -> Self::Output { - let AffinePoint { x, y, zero } = self; - AffinePoint { x, y: -y, zero } - } -} - -impl Neg for ProjectivePoint { - type Output = ProjectivePoint; - - fn neg(self) -> Self::Output { - let ProjectivePoint { x, y, z } = self; - ProjectivePoint { x, y: -y, z } - } -} - -pub fn base_to_scalar(x: C::BaseField) -> C::ScalarField { - C::ScalarField::from_biguint(x.to_canonical_biguint()) -} - -pub fn scalar_to_base(x: C::ScalarField) -> C::BaseField { - C::BaseField::from_biguint(x.to_canonical_biguint()) -} diff --git a/plonky2/src/curve/ecdsa.rs b/plonky2/src/curve/ecdsa.rs deleted file mode 100644 index cabe038a..00000000 --- a/plonky2/src/curve/ecdsa.rs +++ /dev/null @@ -1,78 +0,0 @@ -use serde::{Deserialize, Serialize}; - -use crate::curve::curve_msm::msm_parallel; -use crate::curve::curve_types::{base_to_scalar, AffinePoint, Curve, CurveScalar}; -use crate::field::field_types::Field; - -#[derive(Copy, Clone, Debug, Deserialize, Eq, Hash, PartialEq, Serialize)] -pub struct ECDSASignature { - pub r: C::ScalarField, - pub s: C::ScalarField, -} - -#[derive(Copy, Clone, Debug, Deserialize, Eq, Hash, PartialEq, Serialize)] -pub struct ECDSASecretKey(pub C::ScalarField); - -#[derive(Copy, Clone, Debug, Deserialize, Eq, Hash, PartialEq, Serialize)] -pub struct ECDSAPublicKey(pub AffinePoint); - -pub fn sign_message(msg: C::ScalarField, sk: ECDSASecretKey) -> ECDSASignature { - let (k, rr) = { - let mut k = C::ScalarField::rand(); - let mut rr = (CurveScalar(k) * C::GENERATOR_PROJECTIVE).to_affine(); - while rr.x == C::BaseField::ZERO { - k = C::ScalarField::rand(); - rr = (CurveScalar(k) * C::GENERATOR_PROJECTIVE).to_affine(); - } - (k, rr) - }; - let r = base_to_scalar::(rr.x); - - let s = k.inverse() * (msg + r * sk.0); - - ECDSASignature { r, s } -} - -pub fn verify_message( - msg: C::ScalarField, - sig: ECDSASignature, - pk: ECDSAPublicKey, -) -> bool { - let ECDSASignature { r, s } = sig; - - assert!(pk.0.is_valid()); - - let c = s.inverse(); - let u1 = msg * c; - let u2 = r * c; - - let g = C::GENERATOR_PROJECTIVE; - let w = 5; // Experimentally fastest - let point_proj = msm_parallel(&[u1, u2], &[g, pk.0.to_projective()], w); - let point = point_proj.to_affine(); - - let x = base_to_scalar::(point.x); - r == x -} - -#[cfg(test)] -mod tests { - use crate::curve::curve_types::{Curve, CurveScalar}; - use crate::curve::ecdsa::{sign_message, verify_message, ECDSAPublicKey, ECDSASecretKey}; - use crate::curve::secp256k1::Secp256K1; - use crate::field::field_types::Field; - use crate::field::secp256k1_scalar::Secp256K1Scalar; - - #[test] - fn test_ecdsa_native() { - type C = Secp256K1; - - let msg = Secp256K1Scalar::rand(); - let sk = ECDSASecretKey(Secp256K1Scalar::rand()); - let pk = ECDSAPublicKey((CurveScalar(sk.0) * C::GENERATOR_PROJECTIVE).to_affine()); - - let sig = sign_message(msg, sk); - let result = verify_message(msg, sig, pk); - assert!(result); - } -} diff --git a/plonky2/src/curve/glv.rs b/plonky2/src/curve/glv.rs deleted file mode 100644 index aeeb463e..00000000 --- a/plonky2/src/curve/glv.rs +++ /dev/null @@ -1,136 +0,0 @@ -use num::rational::Ratio; -use num::BigUint; -use plonky2_field::field_types::{Field, PrimeField}; -use plonky2_field::secp256k1_base::Secp256K1Base; -use plonky2_field::secp256k1_scalar::Secp256K1Scalar; - -use crate::curve::curve_msm::msm_parallel; -use crate::curve::curve_types::{AffinePoint, ProjectivePoint}; -use crate::curve::secp256k1::Secp256K1; - -pub const GLV_BETA: Secp256K1Base = Secp256K1Base([ - 13923278643952681454, - 11308619431505398165, - 7954561588662645993, - 8856726876819556112, -]); - -pub const GLV_S: Secp256K1Scalar = Secp256K1Scalar([ - 16069571880186789234, - 1310022930574435960, - 11900229862571533402, - 6008836872998760672, -]); - -const A1: Secp256K1Scalar = Secp256K1Scalar([16747920425669159701, 3496713202691238861, 0, 0]); - -const MINUS_B1: Secp256K1Scalar = - Secp256K1Scalar([8022177200260244675, 16448129721693014056, 0, 0]); - -const A2: Secp256K1Scalar = Secp256K1Scalar([6323353552219852760, 1498098850674701302, 1, 0]); - -const B2: Secp256K1Scalar = Secp256K1Scalar([16747920425669159701, 3496713202691238861, 0, 0]); - -pub fn decompose_secp256k1_scalar( - k: Secp256K1Scalar, -) -> (Secp256K1Scalar, Secp256K1Scalar, bool, bool) { - let p = Secp256K1Scalar::order(); - let c1_biguint = Ratio::new( - B2.to_canonical_biguint() * k.to_canonical_biguint(), - p.clone(), - ) - .round() - .to_integer(); - let c1 = Secp256K1Scalar::from_biguint(c1_biguint); - let c2_biguint = Ratio::new( - MINUS_B1.to_canonical_biguint() * k.to_canonical_biguint(), - p.clone(), - ) - .round() - .to_integer(); - let c2 = Secp256K1Scalar::from_biguint(c2_biguint); - - let k1_raw = k - c1 * A1 - c2 * A2; - let k2_raw = c1 * MINUS_B1 - c2 * B2; - debug_assert!(k1_raw + GLV_S * k2_raw == k); - - let two = BigUint::from_slice(&[2]); - let k1_neg = k1_raw.to_canonical_biguint() > p.clone() / two.clone(); - let k1 = if k1_neg { - Secp256K1Scalar::from_biguint(p.clone() - k1_raw.to_canonical_biguint()) - } else { - k1_raw - }; - let k2_neg = k2_raw.to_canonical_biguint() > p.clone() / two; - let k2 = if k2_neg { - Secp256K1Scalar::from_biguint(p - k2_raw.to_canonical_biguint()) - } else { - k2_raw - }; - - (k1, k2, k1_neg, k2_neg) -} - -pub fn glv_mul(p: ProjectivePoint, k: Secp256K1Scalar) -> ProjectivePoint { - let (k1, k2, k1_neg, k2_neg) = decompose_secp256k1_scalar(k); - /*let one = Secp256K1Scalar::ONE; - let m1 = if k1_neg { -one } else { one }; - let m2 = if k2_neg { -one } else { one }; - assert!(k1 * m1 + S * k2 * m2 == k);*/ - - let p_affine = p.to_affine(); - let sp = AffinePoint:: { - x: p_affine.x * GLV_BETA, - y: p_affine.y, - zero: p_affine.zero, - }; - - let first = if k1_neg { p.neg() } else { p }; - let second = if k2_neg { - sp.to_projective().neg() - } else { - sp.to_projective() - }; - - msm_parallel(&[k1, k2], &[first, second], 5) -} - -#[cfg(test)] -mod tests { - use anyhow::Result; - use plonky2_field::field_types::Field; - use plonky2_field::secp256k1_scalar::Secp256K1Scalar; - - use crate::curve::curve_types::{Curve, CurveScalar}; - use crate::curve::glv::{decompose_secp256k1_scalar, glv_mul, GLV_S}; - use crate::curve::secp256k1::Secp256K1; - - #[test] - fn test_glv_decompose() -> Result<()> { - let k = Secp256K1Scalar::rand(); - let (k1, k2, k1_neg, k2_neg) = decompose_secp256k1_scalar(k); - let one = Secp256K1Scalar::ONE; - let m1 = if k1_neg { -one } else { one }; - let m2 = if k2_neg { -one } else { one }; - - assert!(k1 * m1 + GLV_S * k2 * m2 == k); - - Ok(()) - } - - #[test] - fn test_glv_mul() -> Result<()> { - for _ in 0..20 { - let k = Secp256K1Scalar::rand(); - - let p = CurveScalar(Secp256K1Scalar::rand()) * Secp256K1::GENERATOR_PROJECTIVE; - - let kp = CurveScalar(k) * p; - let glv = glv_mul(p, k); - - assert!(kp == glv); - } - - Ok(()) - } -} diff --git a/plonky2/src/curve/mod.rs b/plonky2/src/curve/mod.rs deleted file mode 100644 index 1984b0c6..00000000 --- a/plonky2/src/curve/mod.rs +++ /dev/null @@ -1,8 +0,0 @@ -pub mod curve_adds; -pub mod curve_msm; -pub mod curve_multiplication; -pub mod curve_summation; -pub mod curve_types; -pub mod ecdsa; -pub mod glv; -pub mod secp256k1; diff --git a/plonky2/src/curve/secp256k1.rs b/plonky2/src/curve/secp256k1.rs deleted file mode 100644 index 18040dae..00000000 --- a/plonky2/src/curve/secp256k1.rs +++ /dev/null @@ -1,101 +0,0 @@ -use plonky2_field::field_types::Field; -use plonky2_field::secp256k1_base::Secp256K1Base; -use plonky2_field::secp256k1_scalar::Secp256K1Scalar; -use serde::{Deserialize, Serialize}; - -use crate::curve::curve_types::{AffinePoint, Curve}; - -#[derive(Debug, Copy, Clone, Deserialize, Eq, Hash, PartialEq, Serialize)] -pub struct Secp256K1; - -impl Curve for Secp256K1 { - type BaseField = Secp256K1Base; - type ScalarField = Secp256K1Scalar; - - const A: Secp256K1Base = Secp256K1Base::ZERO; - const B: Secp256K1Base = Secp256K1Base([7, 0, 0, 0]); - const GENERATOR_AFFINE: AffinePoint = AffinePoint { - x: SECP256K1_GENERATOR_X, - y: SECP256K1_GENERATOR_Y, - zero: false, - }; -} - -// 55066263022277343669578718895168534326250603453777594175500187360389116729240 -const SECP256K1_GENERATOR_X: Secp256K1Base = Secp256K1Base([ - 0x59F2815B16F81798, - 0x029BFCDB2DCE28D9, - 0x55A06295CE870B07, - 0x79BE667EF9DCBBAC, -]); - -/// 32670510020758816978083085130507043184471273380659243275938904335757337482424 -const SECP256K1_GENERATOR_Y: Secp256K1Base = Secp256K1Base([ - 0x9C47D08FFB10D4B8, - 0xFD17B448A6855419, - 0x5DA4FBFC0E1108A8, - 0x483ADA7726A3C465, -]); - -#[cfg(test)] -mod tests { - use num::BigUint; - use plonky2_field::field_types::Field; - use plonky2_field::field_types::PrimeField; - use plonky2_field::secp256k1_scalar::Secp256K1Scalar; - - use crate::curve::curve_types::{AffinePoint, Curve, ProjectivePoint}; - use crate::curve::secp256k1::Secp256K1; - - #[test] - fn test_generator() { - let g = Secp256K1::GENERATOR_AFFINE; - assert!(g.is_valid()); - - let neg_g = AffinePoint:: { - x: g.x, - y: -g.y, - zero: g.zero, - }; - assert!(neg_g.is_valid()); - } - - #[test] - fn test_naive_multiplication() { - let g = Secp256K1::GENERATOR_PROJECTIVE; - let ten = Secp256K1Scalar::from_canonical_u64(10); - let product = mul_naive(ten, g); - let sum = g + g + g + g + g + g + g + g + g + g; - assert_eq!(product, sum); - } - - #[test] - fn test_g1_multiplication() { - let lhs = Secp256K1Scalar::from_biguint(BigUint::from_slice(&[ - 1111, 2222, 3333, 4444, 5555, 6666, 7777, 8888, - ])); - assert_eq!( - Secp256K1::convert(lhs) * Secp256K1::GENERATOR_PROJECTIVE, - mul_naive(lhs, Secp256K1::GENERATOR_PROJECTIVE) - ); - } - - /// A simple, somewhat inefficient implementation of multiplication which is used as a reference - /// for correctness. - fn mul_naive( - lhs: Secp256K1Scalar, - rhs: ProjectivePoint, - ) -> ProjectivePoint { - let mut g = rhs; - let mut sum = ProjectivePoint::ZERO; - for limb in lhs.to_canonical_biguint().to_u64_digits().iter() { - for j in 0..64 { - if (limb >> j & 1u64) != 0u64 { - sum = sum + g; - } - g = g.double(); - } - } - sum - } -} diff --git a/plonky2/src/gadgets/biguint.rs b/plonky2/src/gadgets/biguint.rs deleted file mode 100644 index c9ad7280..00000000 --- a/plonky2/src/gadgets/biguint.rs +++ /dev/null @@ -1,418 +0,0 @@ -use std::marker::PhantomData; - -use num::{BigUint, Integer, Zero}; -use plonky2_field::extension_field::Extendable; - -use crate::gadgets::arithmetic_u32::U32Target; -use crate::hash::hash_types::RichField; -use crate::iop::generator::{GeneratedValues, SimpleGenerator}; -use crate::iop::target::{BoolTarget, Target}; -use crate::iop::witness::{PartitionWitness, Witness}; -use crate::plonk::circuit_builder::CircuitBuilder; - -#[derive(Clone, Debug)] -pub struct BigUintTarget { - pub limbs: Vec, -} - -impl BigUintTarget { - pub fn num_limbs(&self) -> usize { - self.limbs.len() - } - - pub fn get_limb(&self, i: usize) -> U32Target { - self.limbs[i] - } -} - -impl, const D: usize> CircuitBuilder { - pub fn constant_biguint(&mut self, value: &BigUint) -> BigUintTarget { - let limb_values = value.to_u32_digits(); - let limbs = limb_values.iter().map(|&l| self.constant_u32(l)).collect(); - - BigUintTarget { limbs } - } - - pub fn zero_biguint(&mut self) -> BigUintTarget { - self.constant_biguint(&BigUint::zero()) - } - - pub fn connect_biguint(&mut self, lhs: &BigUintTarget, rhs: &BigUintTarget) { - let min_limbs = lhs.num_limbs().min(rhs.num_limbs()); - for i in 0..min_limbs { - self.connect_u32(lhs.get_limb(i), rhs.get_limb(i)); - } - - for i in min_limbs..lhs.num_limbs() { - self.assert_zero_u32(lhs.get_limb(i)); - } - for i in min_limbs..rhs.num_limbs() { - self.assert_zero_u32(rhs.get_limb(i)); - } - } - - pub fn pad_biguints( - &mut self, - a: &BigUintTarget, - b: &BigUintTarget, - ) -> (BigUintTarget, BigUintTarget) { - if a.num_limbs() > b.num_limbs() { - let mut padded_b = b.clone(); - for _ in b.num_limbs()..a.num_limbs() { - padded_b.limbs.push(self.zero_u32()); - } - - (a.clone(), padded_b) - } else { - let mut padded_a = a.clone(); - for _ in a.num_limbs()..b.num_limbs() { - padded_a.limbs.push(self.zero_u32()); - } - - (padded_a, b.clone()) - } - } - - pub fn cmp_biguint(&mut self, a: &BigUintTarget, b: &BigUintTarget) -> BoolTarget { - let (a, b) = self.pad_biguints(a, b); - - self.list_le_u32(a.limbs, b.limbs) - } - - pub fn add_virtual_biguint_target(&mut self, num_limbs: usize) -> BigUintTarget { - let limbs = self.add_virtual_u32_targets(num_limbs); - - BigUintTarget { limbs } - } - - // Add two `BigUintTarget`s. - pub fn add_biguint(&mut self, a: &BigUintTarget, b: &BigUintTarget) -> BigUintTarget { - let num_limbs = a.num_limbs().max(b.num_limbs()); - - let mut combined_limbs = vec![]; - let mut carry = self.zero_u32(); - for i in 0..num_limbs { - let a_limb = (i < a.num_limbs()) - .then(|| a.limbs[i]) - .unwrap_or_else(|| self.zero_u32()); - let b_limb = (i < b.num_limbs()) - .then(|| b.limbs[i]) - .unwrap_or_else(|| self.zero_u32()); - - let (new_limb, new_carry) = self.add_many_u32(&[carry, a_limb, b_limb]); - carry = new_carry; - combined_limbs.push(new_limb); - } - combined_limbs.push(carry); - - BigUintTarget { - limbs: combined_limbs, - } - } - - // Subtract two `BigUintTarget`s. We assume that the first is larger than the second. - pub fn sub_biguint(&mut self, a: &BigUintTarget, b: &BigUintTarget) -> BigUintTarget { - let (a, b) = self.pad_biguints(a, b); - let num_limbs = a.limbs.len(); - - let mut result_limbs = vec![]; - - let mut borrow = self.zero_u32(); - for i in 0..num_limbs { - let (result, new_borrow) = self.sub_u32(a.limbs[i], b.limbs[i], borrow); - result_limbs.push(result); - borrow = new_borrow; - } - // Borrow should be zero here. - - BigUintTarget { - limbs: result_limbs, - } - } - - pub fn mul_biguint(&mut self, a: &BigUintTarget, b: &BigUintTarget) -> BigUintTarget { - let total_limbs = a.limbs.len() + b.limbs.len(); - - let mut to_add = vec![vec![]; total_limbs]; - for i in 0..a.limbs.len() { - for j in 0..b.limbs.len() { - let (product, carry) = self.mul_u32(a.limbs[i], b.limbs[j]); - to_add[i + j].push(product); - to_add[i + j + 1].push(carry); - } - } - - let mut combined_limbs = vec![]; - let mut carry = self.zero_u32(); - for summands in &mut to_add { - let (new_result, new_carry) = self.add_u32s_with_carry(summands, carry); - combined_limbs.push(new_result); - carry = new_carry; - } - combined_limbs.push(carry); - - BigUintTarget { - limbs: combined_limbs, - } - } - - pub fn mul_biguint_by_bool(&mut self, a: &BigUintTarget, b: BoolTarget) -> BigUintTarget { - let t = b.target; - - BigUintTarget { - limbs: a - .limbs - .iter() - .map(|&l| U32Target(self.mul(l.0, t))) - .collect(), - } - } - - // Returns x * y + z. This is no more efficient than mul-then-add; it's purely for convenience (only need to call one CircuitBuilder function). - pub fn mul_add_biguint( - &mut self, - x: &BigUintTarget, - y: &BigUintTarget, - z: &BigUintTarget, - ) -> BigUintTarget { - let prod = self.mul_biguint(x, y); - self.add_biguint(&prod, z) - } - - pub fn div_rem_biguint( - &mut self, - a: &BigUintTarget, - b: &BigUintTarget, - ) -> (BigUintTarget, BigUintTarget) { - let a_len = a.limbs.len(); - let b_len = b.limbs.len(); - let div_num_limbs = if b_len > a_len + 1 { - 0 - } else { - a_len - b_len + 1 - }; - let div = self.add_virtual_biguint_target(div_num_limbs); - let rem = self.add_virtual_biguint_target(b_len); - - self.add_simple_generator(BigUintDivRemGenerator:: { - a: a.clone(), - b: b.clone(), - div: div.clone(), - rem: rem.clone(), - _phantom: PhantomData, - }); - - let div_b = self.mul_biguint(&div, b); - let div_b_plus_rem = self.add_biguint(&div_b, &rem); - self.connect_biguint(a, &div_b_plus_rem); - - let cmp_rem_b = self.cmp_biguint(&rem, b); - self.assert_one(cmp_rem_b.target); - - (div, rem) - } - - pub fn div_biguint(&mut self, a: &BigUintTarget, b: &BigUintTarget) -> BigUintTarget { - let (div, _rem) = self.div_rem_biguint(a, b); - div - } - - pub fn rem_biguint(&mut self, a: &BigUintTarget, b: &BigUintTarget) -> BigUintTarget { - let (_div, rem) = self.div_rem_biguint(a, b); - rem - } -} - -#[derive(Debug)] -struct BigUintDivRemGenerator, const D: usize> { - a: BigUintTarget, - b: BigUintTarget, - div: BigUintTarget, - rem: BigUintTarget, - _phantom: PhantomData, -} - -impl, const D: usize> SimpleGenerator - for BigUintDivRemGenerator -{ - fn dependencies(&self) -> Vec { - self.a - .limbs - .iter() - .chain(&self.b.limbs) - .map(|&l| l.0) - .collect() - } - - fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { - let a = witness.get_biguint_target(self.a.clone()); - let b = witness.get_biguint_target(self.b.clone()); - let (div, rem) = a.div_rem(&b); - - out_buffer.set_biguint_target(self.div.clone(), div); - out_buffer.set_biguint_target(self.rem.clone(), rem); - } -} - -#[cfg(test)] -mod tests { - use anyhow::Result; - use num::{BigUint, FromPrimitive, Integer}; - use rand::Rng; - - use crate::iop::witness::Witness; - use crate::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; - use crate::{ - iop::witness::PartialWitness, - plonk::{circuit_builder::CircuitBuilder, circuit_data::CircuitConfig, verifier::verify}, - }; - - #[test] - fn test_biguint_add() -> Result<()> { - const D: usize = 2; - type C = PoseidonGoldilocksConfig; - type F = >::F; - let mut rng = rand::thread_rng(); - - let x_value = BigUint::from_u128(rng.gen()).unwrap(); - let y_value = BigUint::from_u128(rng.gen()).unwrap(); - let expected_z_value = &x_value + &y_value; - - let config = CircuitConfig::standard_recursion_config(); - let mut pw = PartialWitness::new(); - let mut builder = CircuitBuilder::::new(config); - - let x = builder.add_virtual_biguint_target(x_value.to_u32_digits().len()); - let y = builder.add_virtual_biguint_target(y_value.to_u32_digits().len()); - let z = builder.add_biguint(&x, &y); - let expected_z = builder.add_virtual_biguint_target(expected_z_value.to_u32_digits().len()); - builder.connect_biguint(&z, &expected_z); - - pw.set_biguint_target(&x, &x_value); - pw.set_biguint_target(&y, &y_value); - pw.set_biguint_target(&expected_z, &expected_z_value); - - let data = builder.build::(); - let proof = data.prove(pw).unwrap(); - verify(proof, &data.verifier_only, &data.common) - } - - #[test] - fn test_biguint_sub() -> Result<()> { - const D: usize = 2; - type C = PoseidonGoldilocksConfig; - type F = >::F; - let mut rng = rand::thread_rng(); - - let mut x_value = BigUint::from_u128(rng.gen()).unwrap(); - let mut y_value = BigUint::from_u128(rng.gen()).unwrap(); - if y_value > x_value { - (x_value, y_value) = (y_value, x_value); - } - let expected_z_value = &x_value - &y_value; - - let config = CircuitConfig::standard_recursion_config(); - let pw = PartialWitness::new(); - let mut builder = CircuitBuilder::::new(config); - - let x = builder.constant_biguint(&x_value); - let y = builder.constant_biguint(&y_value); - let z = builder.sub_biguint(&x, &y); - let expected_z = builder.constant_biguint(&expected_z_value); - - builder.connect_biguint(&z, &expected_z); - - let data = builder.build::(); - let proof = data.prove(pw).unwrap(); - verify(proof, &data.verifier_only, &data.common) - } - - #[test] - fn test_biguint_mul() -> Result<()> { - const D: usize = 2; - type C = PoseidonGoldilocksConfig; - type F = >::F; - let mut rng = rand::thread_rng(); - - let x_value = BigUint::from_u128(rng.gen()).unwrap(); - let y_value = BigUint::from_u128(rng.gen()).unwrap(); - let expected_z_value = &x_value * &y_value; - - let config = CircuitConfig::standard_recursion_config(); - let mut pw = PartialWitness::new(); - let mut builder = CircuitBuilder::::new(config); - - let x = builder.add_virtual_biguint_target(x_value.to_u32_digits().len()); - let y = builder.add_virtual_biguint_target(y_value.to_u32_digits().len()); - let z = builder.mul_biguint(&x, &y); - let expected_z = builder.add_virtual_biguint_target(expected_z_value.to_u32_digits().len()); - builder.connect_biguint(&z, &expected_z); - - pw.set_biguint_target(&x, &x_value); - pw.set_biguint_target(&y, &y_value); - pw.set_biguint_target(&expected_z, &expected_z_value); - - let data = builder.build::(); - let proof = data.prove(pw).unwrap(); - verify(proof, &data.verifier_only, &data.common) - } - - #[test] - fn test_biguint_cmp() -> Result<()> { - const D: usize = 2; - type C = PoseidonGoldilocksConfig; - type F = >::F; - let mut rng = rand::thread_rng(); - - let x_value = BigUint::from_u128(rng.gen()).unwrap(); - let y_value = BigUint::from_u128(rng.gen()).unwrap(); - - let config = CircuitConfig::standard_recursion_config(); - let pw = PartialWitness::new(); - let mut builder = CircuitBuilder::::new(config); - - let x = builder.constant_biguint(&x_value); - let y = builder.constant_biguint(&y_value); - let cmp = builder.cmp_biguint(&x, &y); - let expected_cmp = builder.constant_bool(x_value <= y_value); - - builder.connect(cmp.target, expected_cmp.target); - - let data = builder.build::(); - let proof = data.prove(pw).unwrap(); - verify(proof, &data.verifier_only, &data.common) - } - - #[test] - fn test_biguint_div_rem() -> Result<()> { - const D: usize = 2; - type C = PoseidonGoldilocksConfig; - type F = >::F; - let mut rng = rand::thread_rng(); - - let mut x_value = BigUint::from_u128(rng.gen()).unwrap(); - let mut y_value = BigUint::from_u128(rng.gen()).unwrap(); - if y_value > x_value { - (x_value, y_value) = (y_value, x_value); - } - let (expected_div_value, expected_rem_value) = x_value.div_rem(&y_value); - - let config = CircuitConfig::standard_recursion_config(); - let pw = PartialWitness::new(); - let mut builder = CircuitBuilder::::new(config); - - let x = builder.constant_biguint(&x_value); - let y = builder.constant_biguint(&y_value); - let (div, rem) = builder.div_rem_biguint(&x, &y); - - let expected_div = builder.constant_biguint(&expected_div_value); - let expected_rem = builder.constant_biguint(&expected_rem_value); - - builder.connect_biguint(&div, &expected_div); - builder.connect_biguint(&rem, &expected_rem); - - let data = builder.build::(); - let proof = data.prove(pw).unwrap(); - verify(proof, &data.verifier_only, &data.common) - } -} diff --git a/plonky2/src/gadgets/curve.rs b/plonky2/src/gadgets/curve.rs deleted file mode 100644 index e4e66a4e..00000000 --- a/plonky2/src/gadgets/curve.rs +++ /dev/null @@ -1,434 +0,0 @@ -use plonky2_field::extension_field::Extendable; -use plonky2_field::field_types::Field; - -use crate::curve::curve_types::{AffinePoint, Curve, CurveScalar}; -use crate::gadgets::nonnative::NonNativeTarget; -use crate::hash::hash_types::RichField; -use crate::iop::target::BoolTarget; -use crate::plonk::circuit_builder::CircuitBuilder; - -/// A Target representing an affine point on the curve `C`. We use incomplete arithmetic for efficiency, -/// so we assume these points are not zero. -#[derive(Clone, Debug)] -pub struct AffinePointTarget { - pub x: NonNativeTarget, - pub y: NonNativeTarget, -} - -impl AffinePointTarget { - pub fn to_vec(&self) -> Vec> { - vec![self.x.clone(), self.y.clone()] - } -} - -impl, const D: usize> CircuitBuilder { - pub fn constant_affine_point( - &mut self, - point: AffinePoint, - ) -> AffinePointTarget { - debug_assert!(!point.zero); - AffinePointTarget { - x: self.constant_nonnative(point.x), - y: self.constant_nonnative(point.y), - } - } - - pub fn connect_affine_point( - &mut self, - lhs: &AffinePointTarget, - rhs: &AffinePointTarget, - ) { - self.connect_nonnative(&lhs.x, &rhs.x); - self.connect_nonnative(&lhs.y, &rhs.y); - } - - pub fn add_virtual_affine_point_target(&mut self) -> AffinePointTarget { - let x = self.add_virtual_nonnative_target(); - let y = self.add_virtual_nonnative_target(); - - AffinePointTarget { x, y } - } - - pub fn curve_assert_valid(&mut self, p: &AffinePointTarget) { - let a = self.constant_nonnative(C::A); - let b = self.constant_nonnative(C::B); - - let y_squared = self.mul_nonnative(&p.y, &p.y); - let x_squared = self.mul_nonnative(&p.x, &p.x); - let x_cubed = self.mul_nonnative(&x_squared, &p.x); - let a_x = self.mul_nonnative(&a, &p.x); - let a_x_plus_b = self.add_nonnative(&a_x, &b); - let rhs = self.add_nonnative(&x_cubed, &a_x_plus_b); - - self.connect_nonnative(&y_squared, &rhs); - } - - pub fn curve_neg(&mut self, p: &AffinePointTarget) -> AffinePointTarget { - let neg_y = self.neg_nonnative(&p.y); - AffinePointTarget { - x: p.x.clone(), - y: neg_y, - } - } - - pub fn curve_conditional_neg( - &mut self, - p: &AffinePointTarget, - b: BoolTarget, - ) -> AffinePointTarget { - AffinePointTarget { - x: p.x.clone(), - y: self.nonnative_conditional_neg(&p.y, b), - } - } - - pub fn curve_double(&mut self, p: &AffinePointTarget) -> AffinePointTarget { - let AffinePointTarget { x, y } = p; - let double_y = self.add_nonnative(y, y); - let inv_double_y = self.inv_nonnative(&double_y); - let x_squared = self.mul_nonnative(x, x); - let double_x_squared = self.add_nonnative(&x_squared, &x_squared); - let triple_x_squared = self.add_nonnative(&double_x_squared, &x_squared); - - let a = self.constant_nonnative(C::A); - let triple_xx_a = self.add_nonnative(&triple_x_squared, &a); - let lambda = self.mul_nonnative(&triple_xx_a, &inv_double_y); - let lambda_squared = self.mul_nonnative(&lambda, &lambda); - let x_double = self.add_nonnative(x, x); - - let x3 = self.sub_nonnative(&lambda_squared, &x_double); - - let x_diff = self.sub_nonnative(x, &x3); - let lambda_x_diff = self.mul_nonnative(&lambda, &x_diff); - - let y3 = self.sub_nonnative(&lambda_x_diff, y); - - AffinePointTarget { x: x3, y: y3 } - } - - pub fn curve_repeated_double( - &mut self, - p: &AffinePointTarget, - n: usize, - ) -> AffinePointTarget { - let mut result = p.clone(); - - for _ in 0..n { - result = self.curve_double(&result); - } - - result - } - - // Add two points, which are assumed to be non-equal. - pub fn curve_add( - &mut self, - p1: &AffinePointTarget, - p2: &AffinePointTarget, - ) -> AffinePointTarget { - let AffinePointTarget { x: x1, y: y1 } = p1; - let AffinePointTarget { x: x2, y: y2 } = p2; - - let u = self.sub_nonnative(y2, y1); - let v = self.sub_nonnative(x2, x1); - let v_inv = self.inv_nonnative(&v); - let s = self.mul_nonnative(&u, &v_inv); - let s_squared = self.mul_nonnative(&s, &s); - let x_sum = self.add_nonnative(x2, x1); - let x3 = self.sub_nonnative(&s_squared, &x_sum); - let x_diff = self.sub_nonnative(x1, &x3); - let prod = self.mul_nonnative(&s, &x_diff); - let y3 = self.sub_nonnative(&prod, y1); - - AffinePointTarget { x: x3, y: y3 } - } - - pub fn curve_conditional_add( - &mut self, - p1: &AffinePointTarget, - p2: &AffinePointTarget, - b: BoolTarget, - ) -> AffinePointTarget { - let not_b = self.not(b); - let sum = self.curve_add(p1, p2); - let x_if_true = self.mul_nonnative_by_bool(&sum.x, b); - let y_if_true = self.mul_nonnative_by_bool(&sum.y, b); - let x_if_false = self.mul_nonnative_by_bool(&p1.x, not_b); - let y_if_false = self.mul_nonnative_by_bool(&p1.y, not_b); - - let x = self.add_nonnative(&x_if_true, &x_if_false); - let y = self.add_nonnative(&y_if_true, &y_if_false); - - AffinePointTarget { x, y } - } - - pub fn curve_scalar_mul( - &mut self, - p: &AffinePointTarget, - n: &NonNativeTarget, - ) -> AffinePointTarget { - let bits = self.split_nonnative_to_bits(n); - - let rando = (CurveScalar(C::ScalarField::rand()) * C::GENERATOR_PROJECTIVE).to_affine(); - let randot = self.constant_affine_point(rando); - // Result starts at `rando`, which is later subtracted, because we don't support arithmetic with the zero point. - let mut result = self.add_virtual_affine_point_target(); - self.connect_affine_point(&randot, &result); - - let mut two_i_times_p = self.add_virtual_affine_point_target(); - self.connect_affine_point(p, &two_i_times_p); - - for &bit in bits.iter() { - let not_bit = self.not(bit); - - let result_plus_2_i_p = self.curve_add(&result, &two_i_times_p); - - let new_x_if_bit = self.mul_nonnative_by_bool(&result_plus_2_i_p.x, bit); - let new_x_if_not_bit = self.mul_nonnative_by_bool(&result.x, not_bit); - let new_y_if_bit = self.mul_nonnative_by_bool(&result_plus_2_i_p.y, bit); - let new_y_if_not_bit = self.mul_nonnative_by_bool(&result.y, not_bit); - - let new_x = self.add_nonnative(&new_x_if_bit, &new_x_if_not_bit); - let new_y = self.add_nonnative(&new_y_if_bit, &new_y_if_not_bit); - - result = AffinePointTarget { x: new_x, y: new_y }; - - two_i_times_p = self.curve_double(&two_i_times_p); - } - - // Subtract off result's intial value of `rando`. - let neg_r = self.curve_neg(&randot); - result = self.curve_add(&result, &neg_r); - - result - } -} - -#[cfg(test)] -mod tests { - use std::ops::Neg; - - use anyhow::Result; - use plonky2_field::field_types::Field; - use plonky2_field::secp256k1_base::Secp256K1Base; - use plonky2_field::secp256k1_scalar::Secp256K1Scalar; - - use crate::curve::curve_types::{AffinePoint, Curve, CurveScalar}; - use crate::curve::secp256k1::Secp256K1; - use crate::iop::witness::PartialWitness; - use crate::plonk::circuit_builder::CircuitBuilder; - use crate::plonk::circuit_data::CircuitConfig; - use crate::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; - use crate::plonk::verifier::verify; - - #[test] - fn test_curve_point_is_valid() -> Result<()> { - const D: usize = 2; - type C = PoseidonGoldilocksConfig; - type F = >::F; - - let config = CircuitConfig::standard_ecc_config(); - - let pw = PartialWitness::new(); - let mut builder = CircuitBuilder::::new(config); - - let g = Secp256K1::GENERATOR_AFFINE; - let g_target = builder.constant_affine_point(g); - let neg_g_target = builder.curve_neg(&g_target); - - builder.curve_assert_valid(&g_target); - builder.curve_assert_valid(&neg_g_target); - - let data = builder.build::(); - let proof = data.prove(pw).unwrap(); - - verify(proof, &data.verifier_only, &data.common) - } - - #[test] - #[should_panic] - fn test_curve_point_is_not_valid() { - const D: usize = 2; - type C = PoseidonGoldilocksConfig; - type F = >::F; - - let config = CircuitConfig::standard_ecc_config(); - - let pw = PartialWitness::new(); - let mut builder = CircuitBuilder::::new(config); - - let g = Secp256K1::GENERATOR_AFFINE; - let not_g = AffinePoint:: { - x: g.x, - y: g.y + Secp256K1Base::ONE, - zero: g.zero, - }; - let not_g_target = builder.constant_affine_point(not_g); - - builder.curve_assert_valid(¬_g_target); - - let data = builder.build::(); - let proof = data.prove(pw).unwrap(); - - verify(proof, &data.verifier_only, &data.common).unwrap(); - } - - #[test] - fn test_curve_double() -> Result<()> { - const D: usize = 2; - type C = PoseidonGoldilocksConfig; - type F = >::F; - - let config = CircuitConfig::standard_ecc_config(); - - let pw = PartialWitness::new(); - let mut builder = CircuitBuilder::::new(config); - - let g = Secp256K1::GENERATOR_AFFINE; - let g_target = builder.constant_affine_point(g); - let neg_g_target = builder.curve_neg(&g_target); - - let double_g = g.double(); - let double_g_expected = builder.constant_affine_point(double_g); - builder.curve_assert_valid(&double_g_expected); - - let double_neg_g = (-g).double(); - let double_neg_g_expected = builder.constant_affine_point(double_neg_g); - builder.curve_assert_valid(&double_neg_g_expected); - - let double_g_actual = builder.curve_double(&g_target); - let double_neg_g_actual = builder.curve_double(&neg_g_target); - builder.curve_assert_valid(&double_g_actual); - builder.curve_assert_valid(&double_neg_g_actual); - - builder.connect_affine_point(&double_g_expected, &double_g_actual); - builder.connect_affine_point(&double_neg_g_expected, &double_neg_g_actual); - - let data = builder.build::(); - let proof = data.prove(pw).unwrap(); - - verify(proof, &data.verifier_only, &data.common) - } - - #[test] - fn test_curve_add() -> Result<()> { - const D: usize = 2; - type C = PoseidonGoldilocksConfig; - type F = >::F; - - let config = CircuitConfig::standard_ecc_config(); - - let pw = PartialWitness::new(); - let mut builder = CircuitBuilder::::new(config); - - let g = Secp256K1::GENERATOR_AFFINE; - let double_g = g.double(); - let g_plus_2g = (g + double_g).to_affine(); - let g_plus_2g_expected = builder.constant_affine_point(g_plus_2g); - builder.curve_assert_valid(&g_plus_2g_expected); - - let g_target = builder.constant_affine_point(g); - let double_g_target = builder.curve_double(&g_target); - let g_plus_2g_actual = builder.curve_add(&g_target, &double_g_target); - builder.curve_assert_valid(&g_plus_2g_actual); - - builder.connect_affine_point(&g_plus_2g_expected, &g_plus_2g_actual); - - let data = builder.build::(); - let proof = data.prove(pw).unwrap(); - - verify(proof, &data.verifier_only, &data.common) - } - - #[test] - fn test_curve_conditional_add() -> Result<()> { - const D: usize = 2; - type C = PoseidonGoldilocksConfig; - type F = >::F; - - let config = CircuitConfig::standard_ecc_config(); - - let pw = PartialWitness::new(); - let mut builder = CircuitBuilder::::new(config); - - let g = Secp256K1::GENERATOR_AFFINE; - let double_g = g.double(); - let g_plus_2g = (g + double_g).to_affine(); - let g_plus_2g_expected = builder.constant_affine_point(g_plus_2g); - - let g_expected = builder.constant_affine_point(g); - let double_g_target = builder.curve_double(&g_expected); - let t = builder._true(); - let f = builder._false(); - let g_plus_2g_actual = builder.curve_conditional_add(&g_expected, &double_g_target, t); - let g_actual = builder.curve_conditional_add(&g_expected, &double_g_target, f); - - builder.connect_affine_point(&g_plus_2g_expected, &g_plus_2g_actual); - builder.connect_affine_point(&g_expected, &g_actual); - - let data = builder.build::(); - let proof = data.prove(pw).unwrap(); - - verify(proof, &data.verifier_only, &data.common) - } - - #[test] - #[ignore] - fn test_curve_mul() -> Result<()> { - const D: usize = 2; - type C = PoseidonGoldilocksConfig; - type F = >::F; - - let config = CircuitConfig::standard_ecc_config(); - - let pw = PartialWitness::new(); - let mut builder = CircuitBuilder::::new(config); - - let g = Secp256K1::GENERATOR_PROJECTIVE.to_affine(); - let five = Secp256K1Scalar::from_canonical_usize(5); - let neg_five = five.neg(); - let neg_five_scalar = CurveScalar::(neg_five); - let neg_five_g = (neg_five_scalar * g.to_projective()).to_affine(); - let neg_five_g_expected = builder.constant_affine_point(neg_five_g); - builder.curve_assert_valid(&neg_five_g_expected); - - let g_target = builder.constant_affine_point(g); - let neg_five_target = builder.constant_nonnative(neg_five); - let neg_five_g_actual = builder.curve_scalar_mul(&g_target, &neg_five_target); - builder.curve_assert_valid(&neg_five_g_actual); - - builder.connect_affine_point(&neg_five_g_expected, &neg_five_g_actual); - - let data = builder.build::(); - let proof = data.prove(pw).unwrap(); - - verify(proof, &data.verifier_only, &data.common) - } - - #[test] - #[ignore] - fn test_curve_random() -> Result<()> { - const D: usize = 2; - type C = PoseidonGoldilocksConfig; - type F = >::F; - - let config = CircuitConfig::standard_ecc_config(); - - let pw = PartialWitness::new(); - let mut builder = CircuitBuilder::::new(config); - - let rando = - (CurveScalar(Secp256K1Scalar::rand()) * Secp256K1::GENERATOR_PROJECTIVE).to_affine(); - let randot = builder.constant_affine_point(rando); - - let two_target = builder.constant_nonnative(Secp256K1Scalar::TWO); - let randot_doubled = builder.curve_double(&randot); - let randot_times_two = builder.curve_scalar_mul(&randot, &two_target); - builder.connect_affine_point(&randot_doubled, &randot_times_two); - - let data = builder.build::(); - let proof = data.prove(pw).unwrap(); - - verify(proof, &data.verifier_only, &data.common) - } -} diff --git a/plonky2/src/gadgets/curve_fixed_base.rs b/plonky2/src/gadgets/curve_fixed_base.rs deleted file mode 100644 index f28e45d1..00000000 --- a/plonky2/src/gadgets/curve_fixed_base.rs +++ /dev/null @@ -1,110 +0,0 @@ -use num::BigUint; -use plonky2_field::extension_field::Extendable; - -use crate::curve::curve_types::{AffinePoint, Curve, CurveScalar}; -use crate::field::field_types::Field; -use crate::gadgets::curve::AffinePointTarget; -use crate::gadgets::nonnative::NonNativeTarget; -use crate::hash::hash_types::RichField; -use crate::hash::keccak::KeccakHash; -use crate::plonk::circuit_builder::CircuitBuilder; -use crate::plonk::config::{GenericHashOut, Hasher}; - -impl, const D: usize> CircuitBuilder { - /// Compute windowed fixed-base scalar multiplication, using a 4-bit window. - pub fn fixed_base_curve_mul( - &mut self, - base: AffinePoint, - scalar: &NonNativeTarget, - ) -> AffinePointTarget { - // Holds `(16^i) * base` for `i=0..scalar.value.limbs.len() * 8`. - let scaled_base = (0..scalar.value.limbs.len() * 8).scan(base, |acc, _| { - let tmp = *acc; - for _ in 0..4 { - *acc = acc.double(); - } - Some(tmp) - }); - - let limbs = self.split_nonnative_to_4_bit_limbs(scalar); - - let hash_0 = KeccakHash::<32>::hash_no_pad(&[F::ZERO]); - let hash_0_scalar = C::ScalarField::from_biguint(BigUint::from_bytes_le( - &GenericHashOut::::to_bytes(&hash_0), - )); - let rando = (CurveScalar(hash_0_scalar) * C::GENERATOR_PROJECTIVE).to_affine(); - - let zero = self.zero(); - let mut result = self.constant_affine_point(rando); - // `s * P = sum s_i * P_i` with `P_i = (16^i) * P` and `s = sum s_i * (16^i)`. - for (limb, point) in limbs.into_iter().zip(scaled_base) { - // `muls_point[t] = t * P_i` for `t=0..16`. - let muls_point = (0..16) - .scan(AffinePoint::ZERO, |acc, _| { - let tmp = *acc; - *acc = (point + *acc).to_affine(); - Some(tmp) - }) - .map(|p| self.constant_affine_point(p)) - .collect::>(); - let is_zero = self.is_equal(limb, zero); - let should_add = self.not(is_zero); - // `r = s_i * P_i` - let r = self.random_access_curve_points(limb, muls_point); - result = self.curve_conditional_add(&result, &r, should_add); - } - - let to_add = self.constant_affine_point(-rando); - self.curve_add(&result, &to_add) - } -} - -#[cfg(test)] -mod tests { - use anyhow::Result; - use plonky2_field::field_types::PrimeField; - use plonky2_field::secp256k1_scalar::Secp256K1Scalar; - - use crate::curve::curve_types::{Curve, CurveScalar}; - use crate::curve::secp256k1::Secp256K1; - use crate::field::field_types::Field; - use crate::iop::witness::{PartialWitness, Witness}; - use crate::plonk::circuit_builder::CircuitBuilder; - use crate::plonk::circuit_data::CircuitConfig; - use crate::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; - use crate::plonk::verifier::verify; - - #[test] - #[ignore] - fn test_fixed_base() -> Result<()> { - const D: usize = 2; - type C = PoseidonGoldilocksConfig; - type F = >::F; - - let config = CircuitConfig::standard_ecc_config(); - - let mut pw = PartialWitness::new(); - let mut builder = CircuitBuilder::::new(config); - - let g = Secp256K1::GENERATOR_AFFINE; - let n = Secp256K1Scalar::rand(); - - let res = (CurveScalar(n) * g.to_projective()).to_affine(); - let res_expected = builder.constant_affine_point(res); - builder.curve_assert_valid(&res_expected); - - let n_target = builder.add_virtual_nonnative_target::(); - pw.set_biguint_target(&n_target.value, &n.to_canonical_biguint()); - - let res_target = builder.fixed_base_curve_mul(g, &n_target); - builder.curve_assert_valid(&res_target); - - builder.connect_affine_point(&res_target, &res_expected); - - dbg!(builder.num_gates()); - let data = builder.build::(); - let proof = data.prove(pw).unwrap(); - - verify(proof, &data.verifier_only, &data.common) - } -} diff --git a/plonky2/src/gadgets/curve_msm.rs b/plonky2/src/gadgets/curve_msm.rs deleted file mode 100644 index fba7c229..00000000 --- a/plonky2/src/gadgets/curve_msm.rs +++ /dev/null @@ -1,133 +0,0 @@ -use num::BigUint; -use plonky2_field::extension_field::Extendable; - -use crate::curve::curve_types::{Curve, CurveScalar}; -use crate::field::field_types::Field; -use crate::gadgets::curve::AffinePointTarget; -use crate::gadgets::nonnative::NonNativeTarget; -use crate::hash::hash_types::RichField; -use crate::hash::keccak::KeccakHash; -use crate::plonk::circuit_builder::CircuitBuilder; -use crate::plonk::config::{GenericHashOut, Hasher}; - -impl, const D: usize> CircuitBuilder { - /// Computes `n*p + m*q` using windowed MSM, with a 2-bit window. - /// See Algorithm 9.23 in Handbook of Elliptic and Hyperelliptic Curve Cryptography for a - /// description. - /// Note: Doesn't work if `p == q`. - pub fn curve_msm( - &mut self, - p: &AffinePointTarget, - q: &AffinePointTarget, - n: &NonNativeTarget, - m: &NonNativeTarget, - ) -> AffinePointTarget { - let limbs_n = self.split_nonnative_to_2_bit_limbs(n); - let limbs_m = self.split_nonnative_to_2_bit_limbs(m); - assert_eq!(limbs_n.len(), limbs_m.len()); - let num_limbs = limbs_n.len(); - - let hash_0 = KeccakHash::<32>::hash_no_pad(&[F::ZERO]); - let hash_0_scalar = C::ScalarField::from_biguint(BigUint::from_bytes_le( - &GenericHashOut::::to_bytes(&hash_0), - )); - let rando = (CurveScalar(hash_0_scalar) * C::GENERATOR_PROJECTIVE).to_affine(); - let rando_t = self.constant_affine_point(rando); - let neg_rando = self.constant_affine_point(-rando); - - // Precomputes `precomputation[i + 4*j] = i*p + j*q` for `i,j=0..4`. - let mut precomputation = vec![p.clone(); 16]; - let mut cur_p = rando_t.clone(); - let mut cur_q = rando_t.clone(); - for i in 0..4 { - precomputation[i] = cur_p.clone(); - precomputation[4 * i] = cur_q.clone(); - cur_p = self.curve_add(&cur_p, p); - cur_q = self.curve_add(&cur_q, q); - } - for i in 1..4 { - precomputation[i] = self.curve_add(&precomputation[i], &neg_rando); - precomputation[4 * i] = self.curve_add(&precomputation[4 * i], &neg_rando); - } - for i in 1..4 { - for j in 1..4 { - precomputation[i + 4 * j] = - self.curve_add(&precomputation[i], &precomputation[4 * j]); - } - } - - let four = self.constant(F::from_canonical_usize(4)); - - let zero = self.zero(); - let mut result = rando_t; - for (limb_n, limb_m) in limbs_n.into_iter().zip(limbs_m).rev() { - result = self.curve_repeated_double(&result, 2); - let index = self.mul_add(four, limb_m, limb_n); - let r = self.random_access_curve_points(index, precomputation.clone()); - let is_zero = self.is_equal(index, zero); - let should_add = self.not(is_zero); - result = self.curve_conditional_add(&result, &r, should_add); - } - let starting_point_multiplied = (0..2 * num_limbs).fold(rando, |acc, _| acc.double()); - let to_add = self.constant_affine_point(-starting_point_multiplied); - result = self.curve_add(&result, &to_add); - - result - } -} - -#[cfg(test)] -mod tests { - use anyhow::Result; - use plonky2_field::secp256k1_scalar::Secp256K1Scalar; - - use crate::curve::curve_types::{Curve, CurveScalar}; - use crate::curve::secp256k1::Secp256K1; - use crate::field::field_types::Field; - use crate::iop::witness::PartialWitness; - use crate::plonk::circuit_builder::CircuitBuilder; - use crate::plonk::circuit_data::CircuitConfig; - use crate::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; - use crate::plonk::verifier::verify; - - #[test] - #[ignore] - fn test_curve_msm() -> Result<()> { - const D: usize = 2; - type C = PoseidonGoldilocksConfig; - type F = >::F; - - let config = CircuitConfig::standard_ecc_config(); - - let pw = PartialWitness::new(); - let mut builder = CircuitBuilder::::new(config); - - let p = - (CurveScalar(Secp256K1Scalar::rand()) * Secp256K1::GENERATOR_PROJECTIVE).to_affine(); - let q = - (CurveScalar(Secp256K1Scalar::rand()) * Secp256K1::GENERATOR_PROJECTIVE).to_affine(); - let n = Secp256K1Scalar::rand(); - let m = Secp256K1Scalar::rand(); - - let res = - (CurveScalar(n) * p.to_projective() + CurveScalar(m) * q.to_projective()).to_affine(); - let res_expected = builder.constant_affine_point(res); - builder.curve_assert_valid(&res_expected); - - let p_target = builder.constant_affine_point(p); - let q_target = builder.constant_affine_point(q); - let n_target = builder.constant_nonnative(n); - let m_target = builder.constant_nonnative(m); - - let res_target = builder.curve_msm(&p_target, &q_target, &n_target, &m_target); - builder.curve_assert_valid(&res_target); - - builder.connect_affine_point(&res_target, &res_expected); - - dbg!(builder.num_gates()); - let data = builder.build::(); - let proof = data.prove(pw).unwrap(); - - verify(proof, &data.verifier_only, &data.common) - } -} diff --git a/plonky2/src/gadgets/curve_windowed_mul.rs b/plonky2/src/gadgets/curve_windowed_mul.rs deleted file mode 100644 index 46c663a0..00000000 --- a/plonky2/src/gadgets/curve_windowed_mul.rs +++ /dev/null @@ -1,224 +0,0 @@ -use std::marker::PhantomData; - -use num::BigUint; -use plonky2_field::extension_field::Extendable; -use plonky2_field::field_types::Field; - -use crate::curve::curve_types::{Curve, CurveScalar}; -use crate::gadgets::arithmetic_u32::U32Target; -use crate::gadgets::biguint::BigUintTarget; -use crate::gadgets::curve::AffinePointTarget; -use crate::gadgets::nonnative::NonNativeTarget; -use crate::hash::hash_types::RichField; -use crate::hash::keccak::KeccakHash; -use crate::iop::target::{BoolTarget, Target}; -use crate::plonk::circuit_builder::CircuitBuilder; -use crate::plonk::config::{GenericHashOut, Hasher}; - -const WINDOW_SIZE: usize = 4; - -impl, const D: usize> CircuitBuilder { - pub fn precompute_window( - &mut self, - p: &AffinePointTarget, - ) -> Vec> { - let g = (CurveScalar(C::ScalarField::rand()) * C::GENERATOR_PROJECTIVE).to_affine(); - let neg = { - let mut neg = g; - neg.y = -neg.y; - self.constant_affine_point(neg) - }; - - let mut multiples = vec![self.constant_affine_point(g)]; - for i in 1..1 << WINDOW_SIZE { - multiples.push(self.curve_add(p, &multiples[i - 1])); - } - for i in 1..1 << WINDOW_SIZE { - multiples[i] = self.curve_add(&neg, &multiples[i]); - } - multiples - } - - pub fn random_access_curve_points( - &mut self, - access_index: Target, - v: Vec>, - ) -> AffinePointTarget { - let num_limbs = C::BaseField::BITS / 32; - let zero = self.zero_u32(); - let x_limbs: Vec> = (0..num_limbs) - .map(|i| { - v.iter() - .map(|p| p.x.value.limbs.get(i).unwrap_or(&zero).0) - .collect() - }) - .collect(); - let y_limbs: Vec> = (0..num_limbs) - .map(|i| { - v.iter() - .map(|p| p.y.value.limbs.get(i).unwrap_or(&zero).0) - .collect() - }) - .collect(); - - let selected_x_limbs: Vec<_> = x_limbs - .iter() - .map(|limbs| U32Target(self.random_access(access_index, limbs.clone()))) - .collect(); - let selected_y_limbs: Vec<_> = y_limbs - .iter() - .map(|limbs| U32Target(self.random_access(access_index, limbs.clone()))) - .collect(); - - let x = NonNativeTarget { - value: BigUintTarget { - limbs: selected_x_limbs, - }, - _phantom: PhantomData, - }; - let y = NonNativeTarget { - value: BigUintTarget { - limbs: selected_y_limbs, - }, - _phantom: PhantomData, - }; - AffinePointTarget { x, y } - } - - pub fn if_affine_point( - &mut self, - b: BoolTarget, - p1: &AffinePointTarget, - p2: &AffinePointTarget, - ) -> AffinePointTarget { - let new_x = self.if_nonnative(b, &p1.x, &p2.x); - let new_y = self.if_nonnative(b, &p1.y, &p2.y); - AffinePointTarget { x: new_x, y: new_y } - } - - pub fn curve_scalar_mul_windowed( - &mut self, - p: &AffinePointTarget, - n: &NonNativeTarget, - ) -> AffinePointTarget { - let hash_0 = KeccakHash::<25>::hash_no_pad(&[F::ZERO]); - let hash_0_scalar = C::ScalarField::from_biguint(BigUint::from_bytes_le( - &GenericHashOut::::to_bytes(&hash_0), - )); - let starting_point = CurveScalar(hash_0_scalar) * C::GENERATOR_PROJECTIVE; - let starting_point_multiplied = { - let mut cur = starting_point; - for _ in 0..C::ScalarField::BITS { - cur = cur.double(); - } - cur - }; - - let mut result = self.constant_affine_point(starting_point.to_affine()); - - let precomputation = self.precompute_window(p); - let zero = self.zero(); - - let windows = self.split_nonnative_to_4_bit_limbs(n); - for i in (0..windows.len()).rev() { - result = self.curve_repeated_double(&result, WINDOW_SIZE); - let window = windows[i]; - - let to_add = self.random_access_curve_points(window, precomputation.clone()); - let is_zero = self.is_equal(window, zero); - let should_add = self.not(is_zero); - result = self.curve_conditional_add(&result, &to_add, should_add); - } - - let to_subtract = self.constant_affine_point(starting_point_multiplied.to_affine()); - let to_add = self.curve_neg(&to_subtract); - result = self.curve_add(&result, &to_add); - - result - } -} - -#[cfg(test)] -mod tests { - use std::ops::Neg; - - use anyhow::Result; - use plonky2_field::field_types::Field; - use plonky2_field::secp256k1_scalar::Secp256K1Scalar; - use rand::Rng; - - use crate::curve::curve_types::{Curve, CurveScalar}; - use crate::curve::secp256k1::Secp256K1; - use crate::iop::witness::PartialWitness; - use crate::plonk::circuit_builder::CircuitBuilder; - use crate::plonk::circuit_data::CircuitConfig; - use crate::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; - use crate::plonk::verifier::verify; - - #[test] - fn test_random_access_curve_points() -> Result<()> { - const D: usize = 2; - type C = PoseidonGoldilocksConfig; - type F = >::F; - - let config = CircuitConfig::standard_ecc_config(); - - let pw = PartialWitness::new(); - let mut builder = CircuitBuilder::::new(config); - - let num_points = 16; - let points: Vec<_> = (0..num_points) - .map(|_| { - let g = (CurveScalar(Secp256K1Scalar::rand()) * Secp256K1::GENERATOR_PROJECTIVE) - .to_affine(); - builder.constant_affine_point(g) - }) - .collect(); - - let mut rng = rand::thread_rng(); - let access_index = rng.gen::() % num_points; - - let access_index_target = builder.constant(F::from_canonical_usize(access_index)); - let selected = builder.random_access_curve_points(access_index_target, points.clone()); - let expected = points[access_index].clone(); - builder.connect_affine_point(&selected, &expected); - - let data = builder.build::(); - let proof = data.prove(pw).unwrap(); - - verify(proof, &data.verifier_only, &data.common) - } - - #[test] - fn test_curve_windowed_mul() -> Result<()> { - const D: usize = 2; - type C = PoseidonGoldilocksConfig; - type F = >::F; - - let config = CircuitConfig::standard_ecc_config(); - - let pw = PartialWitness::new(); - let mut builder = CircuitBuilder::::new(config); - - let g = - (CurveScalar(Secp256K1Scalar::rand()) * Secp256K1::GENERATOR_PROJECTIVE).to_affine(); - let five = Secp256K1Scalar::from_canonical_usize(5); - let neg_five = five.neg(); - let neg_five_scalar = CurveScalar::(neg_five); - let neg_five_g = (neg_five_scalar * g.to_projective()).to_affine(); - let neg_five_g_expected = builder.constant_affine_point(neg_five_g); - builder.curve_assert_valid(&neg_five_g_expected); - - let g_target = builder.constant_affine_point(g); - let neg_five_target = builder.constant_nonnative(neg_five); - let neg_five_g_actual = builder.curve_scalar_mul_windowed(&g_target, &neg_five_target); - builder.curve_assert_valid(&neg_five_g_actual); - - builder.connect_affine_point(&neg_five_g_expected, &neg_five_g_actual); - - let data = builder.build::(); - let proof = data.prove(pw).unwrap(); - - verify(proof, &data.verifier_only, &data.common) - } -} diff --git a/plonky2/src/gadgets/ecdsa.rs b/plonky2/src/gadgets/ecdsa.rs deleted file mode 100644 index a376e56a..00000000 --- a/plonky2/src/gadgets/ecdsa.rs +++ /dev/null @@ -1,115 +0,0 @@ -use std::marker::PhantomData; - -use plonky2_field::secp256k1_scalar::Secp256K1Scalar; - -use crate::curve::curve_types::Curve; -use crate::curve::secp256k1::Secp256K1; -use crate::field::extension_field::Extendable; -use crate::gadgets::curve::AffinePointTarget; -use crate::gadgets::nonnative::NonNativeTarget; -use crate::hash::hash_types::RichField; -use crate::plonk::circuit_builder::CircuitBuilder; - -#[derive(Clone, Debug)] -pub struct ECDSASecretKeyTarget(NonNativeTarget); - -#[derive(Clone, Debug)] -pub struct ECDSAPublicKeyTarget(AffinePointTarget); - -#[derive(Clone, Debug)] -pub struct ECDSASignatureTarget { - pub r: NonNativeTarget, - pub s: NonNativeTarget, -} - -impl, const D: usize> CircuitBuilder { - pub fn verify_message( - &mut self, - msg: NonNativeTarget, - sig: ECDSASignatureTarget, - pk: ECDSAPublicKeyTarget, - ) { - let ECDSASignatureTarget { r, s } = sig; - - self.curve_assert_valid(&pk.0); - - let c = self.inv_nonnative(&s); - let u1 = self.mul_nonnative(&msg, &c); - let u2 = self.mul_nonnative(&r, &c); - - let point1 = self.fixed_base_curve_mul(Secp256K1::GENERATOR_AFFINE, &u1); - let point2 = self.glv_mul(&pk.0, &u2); - let point = self.curve_add(&point1, &point2); - - let x = NonNativeTarget:: { - value: point.x.value, - _phantom: PhantomData, - }; - self.connect_nonnative(&r, &x); - } -} - -#[cfg(test)] -mod tests { - use anyhow::Result; - - use crate::curve::curve_types::{Curve, CurveScalar}; - use crate::curve::ecdsa::{sign_message, ECDSAPublicKey, ECDSASecretKey, ECDSASignature}; - use crate::curve::secp256k1::Secp256K1; - use crate::field::field_types::Field; - use crate::field::secp256k1_scalar::Secp256K1Scalar; - use crate::gadgets::ecdsa::{ECDSAPublicKeyTarget, ECDSASignatureTarget}; - use crate::iop::witness::PartialWitness; - use crate::plonk::circuit_builder::CircuitBuilder; - use crate::plonk::circuit_data::CircuitConfig; - use crate::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; - use crate::plonk::verifier::verify; - - fn test_ecdsa_circuit_with_config(config: CircuitConfig) -> Result<()> { - const D: usize = 2; - type C = PoseidonGoldilocksConfig; - type F = >::F; - - type Curve = Secp256K1; - - let pw = PartialWitness::new(); - let mut builder = CircuitBuilder::::new(config); - - let msg = Secp256K1Scalar::rand(); - let msg_target = builder.constant_nonnative(msg); - - let sk = ECDSASecretKey::(Secp256K1Scalar::rand()); - let pk = ECDSAPublicKey((CurveScalar(sk.0) * Curve::GENERATOR_PROJECTIVE).to_affine()); - - let pk_target = ECDSAPublicKeyTarget(builder.constant_affine_point(pk.0)); - - let sig = sign_message(msg, sk); - - let ECDSASignature { r, s } = sig; - let r_target = builder.constant_nonnative(r); - let s_target = builder.constant_nonnative(s); - let sig_target = ECDSASignatureTarget { - r: r_target, - s: s_target, - }; - - builder.verify_message(msg_target, sig_target, pk_target); - - dbg!(builder.num_gates()); - let data = builder.build::(); - let proof = data.prove(pw).unwrap(); - verify(proof, &data.verifier_only, &data.common) - } - - #[test] - #[ignore] - fn test_ecdsa_circuit_narrow() -> Result<()> { - test_ecdsa_circuit_with_config(CircuitConfig::standard_ecc_config()) - } - - #[test] - #[ignore] - fn test_ecdsa_circuit_wide() -> Result<()> { - test_ecdsa_circuit_with_config(CircuitConfig::wide_ecc_config()) - } -} diff --git a/plonky2/src/gadgets/glv.rs b/plonky2/src/gadgets/glv.rs deleted file mode 100644 index 8a0179ec..00000000 --- a/plonky2/src/gadgets/glv.rs +++ /dev/null @@ -1,148 +0,0 @@ -use std::marker::PhantomData; - -use plonky2_field::extension_field::Extendable; -use plonky2_field::secp256k1_base::Secp256K1Base; -use plonky2_field::secp256k1_scalar::Secp256K1Scalar; - -use crate::curve::glv::{decompose_secp256k1_scalar, GLV_BETA, GLV_S}; -use crate::curve::secp256k1::Secp256K1; -use crate::gadgets::curve::AffinePointTarget; -use crate::gadgets::nonnative::NonNativeTarget; -use crate::hash::hash_types::RichField; -use crate::iop::generator::{GeneratedValues, SimpleGenerator}; -use crate::iop::target::{BoolTarget, Target}; -use crate::iop::witness::{PartitionWitness, Witness}; -use crate::plonk::circuit_builder::CircuitBuilder; - -impl, const D: usize> CircuitBuilder { - pub fn secp256k1_glv_beta(&mut self) -> NonNativeTarget { - self.constant_nonnative(GLV_BETA) - } - - pub fn decompose_secp256k1_scalar( - &mut self, - k: &NonNativeTarget, - ) -> ( - NonNativeTarget, - NonNativeTarget, - BoolTarget, - BoolTarget, - ) { - let k1 = self.add_virtual_nonnative_target_sized::(4); - let k2 = self.add_virtual_nonnative_target_sized::(4); - let k1_neg = self.add_virtual_bool_target(); - let k2_neg = self.add_virtual_bool_target(); - - self.add_simple_generator(GLVDecompositionGenerator:: { - k: k.clone(), - k1: k1.clone(), - k2: k2.clone(), - k1_neg, - k2_neg, - _phantom: PhantomData, - }); - - // Check that `k1_raw + GLV_S * k2_raw == k`. - let k1_raw = self.nonnative_conditional_neg(&k1, k1_neg); - let k2_raw = self.nonnative_conditional_neg(&k2, k2_neg); - let s = self.constant_nonnative(GLV_S); - let mut should_be_k = self.mul_nonnative(&s, &k2_raw); - should_be_k = self.add_nonnative(&should_be_k, &k1_raw); - self.connect_nonnative(&should_be_k, k); - - (k1, k2, k1_neg, k2_neg) - } - - pub fn glv_mul( - &mut self, - p: &AffinePointTarget, - k: &NonNativeTarget, - ) -> AffinePointTarget { - let (k1, k2, k1_neg, k2_neg) = self.decompose_secp256k1_scalar(k); - - let beta = self.secp256k1_glv_beta(); - let beta_px = self.mul_nonnative(&beta, &p.x); - let sp = AffinePointTarget:: { - x: beta_px, - y: p.y.clone(), - }; - - let p_neg = self.curve_conditional_neg(p, k1_neg); - let sp_neg = self.curve_conditional_neg(&sp, k2_neg); - self.curve_msm(&p_neg, &sp_neg, &k1, &k2) - } -} - -#[derive(Debug)] -struct GLVDecompositionGenerator, const D: usize> { - k: NonNativeTarget, - k1: NonNativeTarget, - k2: NonNativeTarget, - k1_neg: BoolTarget, - k2_neg: BoolTarget, - _phantom: PhantomData, -} - -impl, const D: usize> SimpleGenerator - for GLVDecompositionGenerator -{ - fn dependencies(&self) -> Vec { - self.k.value.limbs.iter().map(|l| l.0).collect() - } - - fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { - let k = witness.get_nonnative_target(self.k.clone()); - let (k1, k2, k1_neg, k2_neg) = decompose_secp256k1_scalar(k); - - out_buffer.set_nonnative_target(self.k1.clone(), k1); - out_buffer.set_nonnative_target(self.k2.clone(), k2); - out_buffer.set_bool_target(self.k1_neg, k1_neg); - out_buffer.set_bool_target(self.k2_neg, k2_neg); - } -} - -#[cfg(test)] -mod tests { - use anyhow::Result; - use plonky2_field::field_types::Field; - use plonky2_field::secp256k1_scalar::Secp256K1Scalar; - - use crate::curve::curve_types::{Curve, CurveScalar}; - use crate::curve::glv::glv_mul; - use crate::curve::secp256k1::Secp256K1; - use crate::iop::witness::PartialWitness; - use crate::plonk::circuit_builder::CircuitBuilder; - use crate::plonk::circuit_data::CircuitConfig; - use crate::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; - use crate::plonk::verifier::verify; - - #[test] - fn test_glv_gadget() -> Result<()> { - const D: usize = 2; - type C = PoseidonGoldilocksConfig; - type F = >::F; - - let config = CircuitConfig::standard_ecc_config(); - - let pw = PartialWitness::new(); - let mut builder = CircuitBuilder::::new(config); - - let rando = - (CurveScalar(Secp256K1Scalar::rand()) * Secp256K1::GENERATOR_PROJECTIVE).to_affine(); - let randot = builder.constant_affine_point(rando); - - let scalar = Secp256K1Scalar::rand(); - let scalar_target = builder.constant_nonnative(scalar); - - let rando_glv_scalar = glv_mul(rando.to_projective(), scalar); - let expected = builder.constant_affine_point(rando_glv_scalar.to_affine()); - let actual = builder.glv_mul(&randot, &scalar_target); - builder.connect_affine_point(&expected, &actual); - - dbg!(builder.num_gates()); - let data = builder.build::(); - let proof = data.prove(pw).unwrap(); - - verify(proof, &data.verifier_only, &data.common) - } -} diff --git a/plonky2/src/gadgets/mod.rs b/plonky2/src/gadgets/mod.rs index 50fc0437..d8613337 100644 --- a/plonky2/src/gadgets/mod.rs +++ b/plonky2/src/gadgets/mod.rs @@ -1,21 +1,12 @@ pub mod arithmetic; pub mod arithmetic_extension; pub mod arithmetic_u32; -pub mod biguint; -pub mod curve; -pub mod curve_fixed_base; -pub mod curve_msm; -pub mod curve_windowed_mul; -pub mod ecdsa; -pub mod glv; pub mod hash; pub mod interpolation; pub mod multiple_comparison; -pub mod nonnative; pub mod polynomial; pub mod random_access; pub mod range_check; pub mod select; pub mod split_base; pub(crate) mod split_join; -pub mod split_nonnative; diff --git a/plonky2/src/gadgets/nonnative.rs b/plonky2/src/gadgets/nonnative.rs deleted file mode 100644 index 6c483a86..00000000 --- a/plonky2/src/gadgets/nonnative.rs +++ /dev/null @@ -1,732 +0,0 @@ -use std::marker::PhantomData; - -use num::{BigUint, Integer, One, Zero}; -use plonky2_field::field_types::PrimeField; -use plonky2_field::{extension_field::Extendable, field_types::Field}; -use plonky2_util::ceil_div_usize; - -use crate::gadgets::arithmetic_u32::U32Target; -use crate::gadgets::biguint::BigUintTarget; -use crate::hash::hash_types::RichField; -use crate::iop::generator::{GeneratedValues, SimpleGenerator}; -use crate::iop::target::{BoolTarget, Target}; -use crate::iop::witness::{PartitionWitness, Witness}; -use crate::plonk::circuit_builder::CircuitBuilder; - -#[derive(Clone, Debug)] -pub struct NonNativeTarget { - pub(crate) value: BigUintTarget, - pub(crate) _phantom: PhantomData, -} - -impl, const D: usize> CircuitBuilder { - fn num_nonnative_limbs() -> usize { - ceil_div_usize(FF::BITS, 32) - } - - pub fn biguint_to_nonnative(&mut self, x: &BigUintTarget) -> NonNativeTarget { - NonNativeTarget { - value: x.clone(), - _phantom: PhantomData, - } - } - - pub fn nonnative_to_canonical_biguint( - &mut self, - x: &NonNativeTarget, - ) -> BigUintTarget { - x.value.clone() - } - - pub fn constant_nonnative(&mut self, x: FF) -> NonNativeTarget { - let x_biguint = self.constant_biguint(&x.to_canonical_biguint()); - self.biguint_to_nonnative(&x_biguint) - } - - pub fn zero_nonnative(&mut self) -> NonNativeTarget { - self.constant_nonnative(FF::ZERO) - } - - // Assert that two NonNativeTarget's, both assumed to be in reduced form, are equal. - pub fn connect_nonnative( - &mut self, - lhs: &NonNativeTarget, - rhs: &NonNativeTarget, - ) { - self.connect_biguint(&lhs.value, &rhs.value); - } - - pub fn add_virtual_nonnative_target(&mut self) -> NonNativeTarget { - let num_limbs = Self::num_nonnative_limbs::(); - let value = self.add_virtual_biguint_target(num_limbs); - - NonNativeTarget { - value, - _phantom: PhantomData, - } - } - - pub fn add_virtual_nonnative_target_sized( - &mut self, - num_limbs: usize, - ) -> NonNativeTarget { - let value = self.add_virtual_biguint_target(num_limbs); - - NonNativeTarget { - value, - _phantom: PhantomData, - } - } - - pub fn add_nonnative( - &mut self, - a: &NonNativeTarget, - b: &NonNativeTarget, - ) -> NonNativeTarget { - let sum = self.add_virtual_nonnative_target::(); - let overflow = self.add_virtual_bool_target(); - - self.add_simple_generator(NonNativeAdditionGenerator:: { - a: a.clone(), - b: b.clone(), - sum: sum.clone(), - overflow, - _phantom: PhantomData, - }); - - let sum_expected = self.add_biguint(&a.value, &b.value); - - let modulus = self.constant_biguint(&FF::order()); - let mod_times_overflow = self.mul_biguint_by_bool(&modulus, overflow); - let sum_actual = self.add_biguint(&sum.value, &mod_times_overflow); - self.connect_biguint(&sum_expected, &sum_actual); - - // Range-check result. - // TODO: can potentially leave unreduced until necessary (e.g. when connecting values). - let cmp = self.cmp_biguint(&sum.value, &modulus); - let one = self.one(); - self.connect(cmp.target, one); - - sum - } - - pub fn mul_nonnative_by_bool( - &mut self, - a: &NonNativeTarget, - b: BoolTarget, - ) -> NonNativeTarget { - NonNativeTarget { - value: self.mul_biguint_by_bool(&a.value, b), - _phantom: PhantomData, - } - } - - pub fn if_nonnative( - &mut self, - b: BoolTarget, - x: &NonNativeTarget, - y: &NonNativeTarget, - ) -> NonNativeTarget { - let not_b = self.not(b); - let maybe_x = self.mul_nonnative_by_bool(x, b); - let maybe_y = self.mul_nonnative_by_bool(y, not_b); - self.add_nonnative(&maybe_x, &maybe_y) - } - - pub fn add_many_nonnative( - &mut self, - to_add: &[NonNativeTarget], - ) -> NonNativeTarget { - if to_add.len() == 1 { - return to_add[0].clone(); - } - - let sum = self.add_virtual_nonnative_target::(); - let overflow = self.add_virtual_u32_target(); - let summands = to_add.to_vec(); - - self.add_simple_generator(NonNativeMultipleAddsGenerator:: { - summands: summands.clone(), - sum: sum.clone(), - overflow, - _phantom: PhantomData, - }); - - self.range_check_u32(sum.value.limbs.clone()); - self.range_check_u32(vec![overflow]); - - let sum_expected = summands - .iter() - .fold(self.zero_biguint(), |a, b| self.add_biguint(&a, &b.value)); - - let modulus = self.constant_biguint(&FF::order()); - let overflow_biguint = BigUintTarget { - limbs: vec![overflow], - }; - let mod_times_overflow = self.mul_biguint(&modulus, &overflow_biguint); - let sum_actual = self.add_biguint(&sum.value, &mod_times_overflow); - self.connect_biguint(&sum_expected, &sum_actual); - - // Range-check result. - // TODO: can potentially leave unreduced until necessary (e.g. when connecting values). - let cmp = self.cmp_biguint(&sum.value, &modulus); - let one = self.one(); - self.connect(cmp.target, one); - - sum - } - - // Subtract two `NonNativeTarget`s. - pub fn sub_nonnative( - &mut self, - a: &NonNativeTarget, - b: &NonNativeTarget, - ) -> NonNativeTarget { - let diff = self.add_virtual_nonnative_target::(); - let overflow = self.add_virtual_bool_target(); - - self.add_simple_generator(NonNativeSubtractionGenerator:: { - a: a.clone(), - b: b.clone(), - diff: diff.clone(), - overflow, - _phantom: PhantomData, - }); - - self.range_check_u32(diff.value.limbs.clone()); - self.assert_bool(overflow); - - let diff_plus_b = self.add_biguint(&diff.value, &b.value); - let modulus = self.constant_biguint(&FF::order()); - let mod_times_overflow = self.mul_biguint_by_bool(&modulus, overflow); - let diff_plus_b_reduced = self.sub_biguint(&diff_plus_b, &mod_times_overflow); - self.connect_biguint(&a.value, &diff_plus_b_reduced); - - diff - } - - pub fn mul_nonnative( - &mut self, - a: &NonNativeTarget, - b: &NonNativeTarget, - ) -> NonNativeTarget { - let prod = self.add_virtual_nonnative_target::(); - let modulus = self.constant_biguint(&FF::order()); - let overflow = self.add_virtual_biguint_target( - a.value.num_limbs() + b.value.num_limbs() - modulus.num_limbs(), - ); - - self.add_simple_generator(NonNativeMultiplicationGenerator:: { - a: a.clone(), - b: b.clone(), - prod: prod.clone(), - overflow: overflow.clone(), - _phantom: PhantomData, - }); - - self.range_check_u32(prod.value.limbs.clone()); - self.range_check_u32(overflow.limbs.clone()); - - let prod_expected = self.mul_biguint(&a.value, &b.value); - - let mod_times_overflow = self.mul_biguint(&modulus, &overflow); - let prod_actual = self.add_biguint(&prod.value, &mod_times_overflow); - self.connect_biguint(&prod_expected, &prod_actual); - - prod - } - - pub fn mul_many_nonnative( - &mut self, - to_mul: &[NonNativeTarget], - ) -> NonNativeTarget { - if to_mul.len() == 1 { - return to_mul[0].clone(); - } - - let mut accumulator = self.mul_nonnative(&to_mul[0], &to_mul[1]); - for i in 2..to_mul.len() { - accumulator = self.mul_nonnative(&accumulator, &to_mul[i]); - } - accumulator - } - - pub fn neg_nonnative( - &mut self, - x: &NonNativeTarget, - ) -> NonNativeTarget { - let zero_target = self.constant_biguint(&BigUint::zero()); - let zero_ff = self.biguint_to_nonnative(&zero_target); - - self.sub_nonnative(&zero_ff, x) - } - - pub fn inv_nonnative( - &mut self, - x: &NonNativeTarget, - ) -> NonNativeTarget { - let num_limbs = x.value.num_limbs(); - let inv_biguint = self.add_virtual_biguint_target(num_limbs); - let div = self.add_virtual_biguint_target(num_limbs); - - self.add_simple_generator(NonNativeInverseGenerator:: { - x: x.clone(), - inv: inv_biguint.clone(), - div: div.clone(), - _phantom: PhantomData, - }); - - let product = self.mul_biguint(&x.value, &inv_biguint); - - let modulus = self.constant_biguint(&FF::order()); - let mod_times_div = self.mul_biguint(&modulus, &div); - let one = self.constant_biguint(&BigUint::one()); - let expected_product = self.add_biguint(&mod_times_div, &one); - self.connect_biguint(&product, &expected_product); - - NonNativeTarget:: { - value: inv_biguint, - _phantom: PhantomData, - } - } - - /// Returns `x % |FF|` as a `NonNativeTarget`. - fn reduce(&mut self, x: &BigUintTarget) -> NonNativeTarget { - let modulus = FF::order(); - let order_target = self.constant_biguint(&modulus); - let value = self.rem_biguint(x, &order_target); - - NonNativeTarget { - value, - _phantom: PhantomData, - } - } - - pub fn reduce_nonnative(&mut self, x: &NonNativeTarget) -> NonNativeTarget { - let x_biguint = self.nonnative_to_canonical_biguint(x); - self.reduce(&x_biguint) - } - - pub fn bool_to_nonnative(&mut self, b: &BoolTarget) -> NonNativeTarget { - let limbs = vec![U32Target(b.target)]; - let value = BigUintTarget { limbs }; - - NonNativeTarget { - value, - _phantom: PhantomData, - } - } - - // Split a nonnative field element to bits. - pub fn split_nonnative_to_bits( - &mut self, - x: &NonNativeTarget, - ) -> Vec { - let num_limbs = x.value.num_limbs(); - let mut result = Vec::with_capacity(num_limbs * 32); - - for i in 0..num_limbs { - let limb = x.value.get_limb(i); - let bit_targets = self.split_le_base::<2>(limb.0, 32); - let mut bits: Vec<_> = bit_targets - .iter() - .map(|&t| BoolTarget::new_unsafe(t)) - .collect(); - - result.append(&mut bits); - } - - result - } - - pub fn nonnative_conditional_neg( - &mut self, - x: &NonNativeTarget, - b: BoolTarget, - ) -> NonNativeTarget { - let not_b = self.not(b); - let neg = self.neg_nonnative(x); - let x_if_true = self.mul_nonnative_by_bool(&neg, b); - let x_if_false = self.mul_nonnative_by_bool(x, not_b); - - self.add_nonnative(&x_if_true, &x_if_false) - } -} - -#[derive(Debug)] -struct NonNativeAdditionGenerator, const D: usize, FF: PrimeField> { - a: NonNativeTarget, - b: NonNativeTarget, - sum: NonNativeTarget, - overflow: BoolTarget, - _phantom: PhantomData, -} - -impl, const D: usize, FF: PrimeField> SimpleGenerator - for NonNativeAdditionGenerator -{ - fn dependencies(&self) -> Vec { - self.a - .value - .limbs - .iter() - .cloned() - .chain(self.b.value.limbs.clone()) - .map(|l| l.0) - .collect() - } - - fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { - let a = witness.get_nonnative_target(self.a.clone()); - let b = witness.get_nonnative_target(self.b.clone()); - let a_biguint = a.to_canonical_biguint(); - let b_biguint = b.to_canonical_biguint(); - let sum_biguint = a_biguint + b_biguint; - let modulus = FF::order(); - let (overflow, sum_reduced) = if sum_biguint > modulus { - (true, sum_biguint - modulus) - } else { - (false, sum_biguint) - }; - - out_buffer.set_biguint_target(self.sum.value.clone(), sum_reduced); - out_buffer.set_bool_target(self.overflow, overflow); - } -} - -#[derive(Debug)] -struct NonNativeMultipleAddsGenerator, const D: usize, FF: PrimeField> -{ - summands: Vec>, - sum: NonNativeTarget, - overflow: U32Target, - _phantom: PhantomData, -} - -impl, const D: usize, FF: PrimeField> SimpleGenerator - for NonNativeMultipleAddsGenerator -{ - fn dependencies(&self) -> Vec { - self.summands - .iter() - .flat_map(|summand| summand.value.limbs.iter().map(|limb| limb.0)) - .collect() - } - - fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { - let summands: Vec<_> = self - .summands - .iter() - .map(|summand| witness.get_nonnative_target(summand.clone())) - .collect(); - let summand_biguints: Vec<_> = summands - .iter() - .map(|summand| summand.to_canonical_biguint()) - .collect(); - - let sum_biguint = summand_biguints - .iter() - .fold(BigUint::zero(), |a, b| a + b.clone()); - - let modulus = FF::order(); - let (overflow_biguint, sum_reduced) = sum_biguint.div_rem(&modulus); - let overflow = overflow_biguint.to_u64_digits()[0] as u32; - - out_buffer.set_biguint_target(self.sum.value.clone(), sum_reduced); - out_buffer.set_u32_target(self.overflow, overflow); - } -} - -#[derive(Debug)] -struct NonNativeSubtractionGenerator, const D: usize, FF: Field> { - a: NonNativeTarget, - b: NonNativeTarget, - diff: NonNativeTarget, - overflow: BoolTarget, - _phantom: PhantomData, -} - -impl, const D: usize, FF: PrimeField> SimpleGenerator - for NonNativeSubtractionGenerator -{ - fn dependencies(&self) -> Vec { - self.a - .value - .limbs - .iter() - .cloned() - .chain(self.b.value.limbs.clone()) - .map(|l| l.0) - .collect() - } - - fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { - let a = witness.get_nonnative_target(self.a.clone()); - let b = witness.get_nonnative_target(self.b.clone()); - let a_biguint = a.to_canonical_biguint(); - let b_biguint = b.to_canonical_biguint(); - - let modulus = FF::order(); - let (diff_biguint, overflow) = if a_biguint >= b_biguint { - (a_biguint - b_biguint, false) - } else { - (modulus + a_biguint - b_biguint, true) - }; - - out_buffer.set_biguint_target(self.diff.value.clone(), diff_biguint); - out_buffer.set_bool_target(self.overflow, overflow); - } -} - -#[derive(Debug)] -struct NonNativeMultiplicationGenerator, const D: usize, FF: Field> { - a: NonNativeTarget, - b: NonNativeTarget, - prod: NonNativeTarget, - overflow: BigUintTarget, - _phantom: PhantomData, -} - -impl, const D: usize, FF: PrimeField> SimpleGenerator - for NonNativeMultiplicationGenerator -{ - fn dependencies(&self) -> Vec { - self.a - .value - .limbs - .iter() - .cloned() - .chain(self.b.value.limbs.clone()) - .map(|l| l.0) - .collect() - } - - fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { - let a = witness.get_nonnative_target(self.a.clone()); - let b = witness.get_nonnative_target(self.b.clone()); - let a_biguint = a.to_canonical_biguint(); - let b_biguint = b.to_canonical_biguint(); - - let prod_biguint = a_biguint * b_biguint; - - let modulus = FF::order(); - let (overflow_biguint, prod_reduced) = prod_biguint.div_rem(&modulus); - - out_buffer.set_biguint_target(self.prod.value.clone(), prod_reduced); - out_buffer.set_biguint_target(self.overflow.clone(), overflow_biguint); - } -} - -#[derive(Debug)] -struct NonNativeInverseGenerator, const D: usize, FF: PrimeField> { - x: NonNativeTarget, - inv: BigUintTarget, - div: BigUintTarget, - _phantom: PhantomData, -} - -impl, const D: usize, FF: PrimeField> SimpleGenerator - for NonNativeInverseGenerator -{ - fn dependencies(&self) -> Vec { - self.x.value.limbs.iter().map(|&l| l.0).collect() - } - - fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { - let x = witness.get_nonnative_target(self.x.clone()); - let inv = x.inverse(); - - let x_biguint = x.to_canonical_biguint(); - let inv_biguint = inv.to_canonical_biguint(); - let prod = x_biguint * &inv_biguint; - let modulus = FF::order(); - let (div, _rem) = prod.div_rem(&modulus); - - out_buffer.set_biguint_target(self.div.clone(), div); - out_buffer.set_biguint_target(self.inv.clone(), inv_biguint); - } -} - -#[cfg(test)] -mod tests { - use anyhow::Result; - use plonky2_field::field_types::{Field, PrimeField}; - use plonky2_field::secp256k1_base::Secp256K1Base; - - use crate::iop::witness::PartialWitness; - use crate::plonk::circuit_builder::CircuitBuilder; - use crate::plonk::circuit_data::CircuitConfig; - use crate::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; - use crate::plonk::verifier::verify; - - #[test] - fn test_nonnative_add() -> Result<()> { - type FF = Secp256K1Base; - const D: usize = 2; - type C = PoseidonGoldilocksConfig; - type F = >::F; - - let x_ff = FF::rand(); - let y_ff = FF::rand(); - let sum_ff = x_ff + y_ff; - - let config = CircuitConfig::standard_ecc_config(); - let pw = PartialWitness::new(); - let mut builder = CircuitBuilder::::new(config); - - let x = builder.constant_nonnative(x_ff); - let y = builder.constant_nonnative(y_ff); - let sum = builder.add_nonnative(&x, &y); - - let sum_expected = builder.constant_nonnative(sum_ff); - builder.connect_nonnative(&sum, &sum_expected); - - let data = builder.build::(); - let proof = data.prove(pw).unwrap(); - verify(proof, &data.verifier_only, &data.common) - } - - #[test] - fn test_nonnative_many_adds() -> Result<()> { - type FF = Secp256K1Base; - const D: usize = 2; - type C = PoseidonGoldilocksConfig; - type F = >::F; - - let a_ff = FF::rand(); - let b_ff = FF::rand(); - let c_ff = FF::rand(); - let d_ff = FF::rand(); - let e_ff = FF::rand(); - let f_ff = FF::rand(); - let g_ff = FF::rand(); - let h_ff = FF::rand(); - let sum_ff = a_ff + b_ff + c_ff + d_ff + e_ff + f_ff + g_ff + h_ff; - - let config = CircuitConfig::standard_ecc_config(); - let pw = PartialWitness::new(); - let mut builder = CircuitBuilder::::new(config); - - let a = builder.constant_nonnative(a_ff); - let b = builder.constant_nonnative(b_ff); - let c = builder.constant_nonnative(c_ff); - let d = builder.constant_nonnative(d_ff); - let e = builder.constant_nonnative(e_ff); - let f = builder.constant_nonnative(f_ff); - let g = builder.constant_nonnative(g_ff); - let h = builder.constant_nonnative(h_ff); - let all = [a, b, c, d, e, f, g, h]; - let sum = builder.add_many_nonnative(&all); - - let sum_expected = builder.constant_nonnative(sum_ff); - builder.connect_nonnative(&sum, &sum_expected); - - let data = builder.build::(); - let proof = data.prove(pw).unwrap(); - verify(proof, &data.verifier_only, &data.common) - } - - #[test] - fn test_nonnative_sub() -> Result<()> { - type FF = Secp256K1Base; - const D: usize = 2; - type C = PoseidonGoldilocksConfig; - type F = >::F; - - let x_ff = FF::rand(); - let mut y_ff = FF::rand(); - while y_ff.to_canonical_biguint() > x_ff.to_canonical_biguint() { - y_ff = FF::rand(); - } - let diff_ff = x_ff - y_ff; - - let config = CircuitConfig::standard_ecc_config(); - let pw = PartialWitness::new(); - let mut builder = CircuitBuilder::::new(config); - - let x = builder.constant_nonnative(x_ff); - let y = builder.constant_nonnative(y_ff); - let diff = builder.sub_nonnative(&x, &y); - - let diff_expected = builder.constant_nonnative(diff_ff); - builder.connect_nonnative(&diff, &diff_expected); - - let data = builder.build::(); - let proof = data.prove(pw).unwrap(); - verify(proof, &data.verifier_only, &data.common) - } - - #[test] - fn test_nonnative_mul() -> Result<()> { - type FF = Secp256K1Base; - const D: usize = 2; - type C = PoseidonGoldilocksConfig; - type F = >::F; - let x_ff = FF::rand(); - let y_ff = FF::rand(); - let product_ff = x_ff * y_ff; - - let config = CircuitConfig::standard_ecc_config(); - let pw = PartialWitness::new(); - let mut builder = CircuitBuilder::::new(config); - - let x = builder.constant_nonnative(x_ff); - let y = builder.constant_nonnative(y_ff); - let product = builder.mul_nonnative(&x, &y); - - let product_expected = builder.constant_nonnative(product_ff); - builder.connect_nonnative(&product, &product_expected); - - let data = builder.build::(); - let proof = data.prove(pw).unwrap(); - verify(proof, &data.verifier_only, &data.common) - } - - #[test] - fn test_nonnative_neg() -> Result<()> { - type FF = Secp256K1Base; - const D: usize = 2; - type C = PoseidonGoldilocksConfig; - type F = >::F; - let x_ff = FF::rand(); - let neg_x_ff = -x_ff; - - let config = CircuitConfig::standard_ecc_config(); - let pw = PartialWitness::new(); - let mut builder = CircuitBuilder::::new(config); - - let x = builder.constant_nonnative(x_ff); - let neg_x = builder.neg_nonnative(&x); - - let neg_x_expected = builder.constant_nonnative(neg_x_ff); - builder.connect_nonnative(&neg_x, &neg_x_expected); - - let data = builder.build::(); - let proof = data.prove(pw).unwrap(); - verify(proof, &data.verifier_only, &data.common) - } - - #[test] - fn test_nonnative_inv() -> Result<()> { - type FF = Secp256K1Base; - const D: usize = 2; - type C = PoseidonGoldilocksConfig; - type F = >::F; - let x_ff = FF::rand(); - let inv_x_ff = x_ff.inverse(); - - let config = CircuitConfig::standard_ecc_config(); - let pw = PartialWitness::new(); - let mut builder = CircuitBuilder::::new(config); - - let x = builder.constant_nonnative(x_ff); - let inv_x = builder.inv_nonnative(&x); - - let inv_x_expected = builder.constant_nonnative(inv_x_ff); - builder.connect_nonnative(&inv_x, &inv_x_expected); - - let data = builder.build::(); - let proof = data.prove(pw).unwrap(); - verify(proof, &data.verifier_only, &data.common) - } -} diff --git a/plonky2/src/gadgets/split_nonnative.rs b/plonky2/src/gadgets/split_nonnative.rs deleted file mode 100644 index 18fc0264..00000000 --- a/plonky2/src/gadgets/split_nonnative.rs +++ /dev/null @@ -1,109 +0,0 @@ -use std::marker::PhantomData; - -use itertools::Itertools; -use plonky2_field::extension_field::Extendable; -use plonky2_field::field_types::Field; - -use crate::gadgets::arithmetic_u32::U32Target; -use crate::gadgets::biguint::BigUintTarget; -use crate::gadgets::nonnative::NonNativeTarget; -use crate::hash::hash_types::RichField; -use crate::iop::target::Target; -use crate::plonk::circuit_builder::CircuitBuilder; - -impl, const D: usize> CircuitBuilder { - pub fn split_u32_to_4_bit_limbs(&mut self, val: U32Target) -> Vec { - let two_bit_limbs = self.split_le_base::<4>(val.0, 16); - let four = self.constant(F::from_canonical_usize(4)); - let combined_limbs = two_bit_limbs - .iter() - .tuples() - .map(|(&a, &b)| self.mul_add(b, four, a)) - .collect(); - - combined_limbs - } - - pub fn split_nonnative_to_4_bit_limbs( - &mut self, - val: &NonNativeTarget, - ) -> Vec { - val.value - .limbs - .iter() - .flat_map(|&l| self.split_u32_to_4_bit_limbs(l)) - .collect() - } - - pub fn split_nonnative_to_2_bit_limbs( - &mut self, - val: &NonNativeTarget, - ) -> Vec { - val.value - .limbs - .iter() - .flat_map(|&l| self.split_le_base::<4>(l.0, 16)) - .collect() - } - - // Note: assumes its inputs are 4-bit limbs, and does not range-check. - pub fn recombine_nonnative_4_bit_limbs( - &mut self, - limbs: Vec, - ) -> NonNativeTarget { - let base = self.constant_u32(1 << 4); - let u32_limbs = limbs - .chunks(8) - .map(|chunk| { - let mut combined_chunk = self.zero_u32(); - for i in (0..8).rev() { - let (low, _high) = self.mul_add_u32(combined_chunk, base, U32Target(chunk[i])); - combined_chunk = low; - } - combined_chunk - }) - .collect(); - - NonNativeTarget { - value: BigUintTarget { limbs: u32_limbs }, - _phantom: PhantomData, - } - } -} - -#[cfg(test)] -mod tests { - use anyhow::Result; - use plonky2_field::field_types::Field; - use plonky2_field::secp256k1_scalar::Secp256K1Scalar; - - use crate::gadgets::nonnative::NonNativeTarget; - use crate::iop::witness::PartialWitness; - use crate::plonk::circuit_builder::CircuitBuilder; - use crate::plonk::circuit_data::CircuitConfig; - use crate::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; - use crate::plonk::verifier::verify; - - #[test] - fn test_split_nonnative() -> Result<()> { - type FF = Secp256K1Scalar; - const D: usize = 2; - type C = PoseidonGoldilocksConfig; - type F = >::F; - - let config = CircuitConfig::standard_ecc_config(); - let pw = PartialWitness::new(); - let mut builder = CircuitBuilder::::new(config); - - let x = FF::rand(); - let x_target = builder.constant_nonnative(x); - let split = builder.split_nonnative_to_4_bit_limbs(&x_target); - let combined: NonNativeTarget = - builder.recombine_nonnative_4_bit_limbs(split); - builder.connect_nonnative(&x_target, &combined); - - let data = builder.build::(); - let proof = data.prove(pw).unwrap(); - verify(proof, &data.verifier_only, &data.common) - } -} diff --git a/plonky2/src/iop/generator.rs b/plonky2/src/iop/generator.rs index 1569e889..f36ba3aa 100644 --- a/plonky2/src/iop/generator.rs +++ b/plonky2/src/iop/generator.rs @@ -1,13 +1,10 @@ use std::fmt::Debug; use std::marker::PhantomData; -use num::BigUint; use plonky2_field::extension_field::{Extendable, FieldExtension}; -use plonky2_field::field_types::{Field, PrimeField}; +use plonky2_field::field_types::Field; use crate::gadgets::arithmetic_u32::U32Target; -use crate::gadgets::biguint::BigUintTarget; -use crate::gadgets::nonnative::NonNativeTarget; use crate::hash::hash_types::{HashOut, HashOutTarget, RichField}; use crate::iop::ext_target::ExtensionTarget; use crate::iop::target::{BoolTarget, Target}; @@ -169,21 +166,6 @@ impl GeneratedValues { self.set_target(target.0, F::from_canonical_u32(value)) } - pub fn set_biguint_target(&mut self, target: BigUintTarget, value: BigUint) { - let mut limbs = value.to_u32_digits(); - - assert!(target.num_limbs() >= limbs.len()); - - limbs.resize(target.num_limbs(), 0); - for i in 0..target.num_limbs() { - self.set_u32_target(target.get_limb(i), limbs[i]); - } - } - - pub fn set_nonnative_target(&mut self, target: NonNativeTarget, value: FF) { - self.set_biguint_target(target.value, value.to_canonical_biguint()) - } - pub fn set_hash_target(&mut self, ht: HashOutTarget, value: HashOut) { ht.elements .iter() diff --git a/plonky2/src/iop/witness.rs b/plonky2/src/iop/witness.rs index a013f811..4a565d02 100644 --- a/plonky2/src/iop/witness.rs +++ b/plonky2/src/iop/witness.rs @@ -1,16 +1,12 @@ use std::collections::HashMap; -use std::iter::repeat; use itertools::Itertools; -use num::{BigUint, FromPrimitive, Zero}; use plonky2_field::extension_field::{Extendable, FieldExtension}; -use plonky2_field::field_types::{Field, PrimeField}; +use plonky2_field::field_types::Field; use crate::fri::structure::{FriOpenings, FriOpeningsTarget}; use crate::fri::witness_util::set_fri_proof_target; use crate::gadgets::arithmetic_u32::U32Target; -use crate::gadgets::biguint::BigUintTarget; -use crate::gadgets::nonnative::NonNativeTarget; use crate::hash::hash_types::HashOutTarget; use crate::hash::hash_types::RichField; use crate::hash::hash_types::{HashOut, MerkleCapTarget}; @@ -64,30 +60,6 @@ pub trait Witness { panic!("not a bool") } - fn get_biguint_target(&self, target: BigUintTarget) -> BigUint - where - F: PrimeField, - { - let mut result = BigUint::zero(); - - let limb_base = BigUint::from_u64(1 << 32u64).unwrap(); - for i in (0..target.num_limbs()).rev() { - let limb = target.get_limb(i); - result *= &limb_base; - result += self.get_target(limb.0).to_canonical_biguint(); - } - - result - } - - fn get_nonnative_target(&self, target: NonNativeTarget) -> FF - where - F: PrimeField, - { - let val = self.get_biguint_target(target.value); - FF::from_biguint(val) - } - fn get_hash_target(&self, ht: HashOutTarget) -> HashOut { HashOut { elements: self.get_targets(&ht.elements).try_into().unwrap(), @@ -160,16 +132,6 @@ pub trait Witness { self.set_target(target.0, F::from_canonical_u32(value)) } - fn set_biguint_target(&mut self, target: &BigUintTarget, value: &BigUint) { - for (<, l) in target - .limbs - .iter() - .zip(value.to_u32_digits().into_iter().chain(repeat(0))) - { - self.set_u32_target(lt, l); - } - } - /// Set the targets in a `ProofWithPublicInputsTarget` to their corresponding values in a /// `ProofWithPublicInputs`. fn set_proof_with_pis_target, const D: usize>( diff --git a/plonky2/src/lib.rs b/plonky2/src/lib.rs index e5e77bb9..1502cea9 100644 --- a/plonky2/src/lib.rs +++ b/plonky2/src/lib.rs @@ -12,7 +12,6 @@ pub use plonky2_field as field; -pub mod curve; pub mod fri; pub mod gadgets; pub mod gates; From 534ee7d637e7a69b98b856e0e5a6aaa4f56b6ad9 Mon Sep 17 00:00:00 2001 From: wborgeaud Date: Wed, 16 Mar 2022 05:39:45 +0100 Subject: [PATCH 53/56] Add untracked files --- ecdsa/Cargo.toml | 19 + ecdsa/src/curve/curve_adds.rs | 158 +++++ ecdsa/src/curve/curve_msm.rs | 265 ++++++++ ecdsa/src/curve/curve_multiplication.rs | 100 +++ ecdsa/src/curve/curve_summation.rs | 239 +++++++ ecdsa/src/curve/curve_types.rs | 285 ++++++++ ecdsa/src/curve/ecdsa.rs | 79 +++ ecdsa/src/curve/glv.rs | 136 ++++ ecdsa/src/curve/mod.rs | 8 + ecdsa/src/curve/secp256k1.rs | 101 +++ ecdsa/src/gadgets/biguint.rs | 505 +++++++++++++++ ecdsa/src/gadgets/curve.rs | 484 ++++++++++++++ ecdsa/src/gadgets/curve_fixed_base.rs | 113 ++++ ecdsa/src/gadgets/curve_msm.rs | 136 ++++ ecdsa/src/gadgets/curve_windowed_mul.rs | 256 ++++++++ ecdsa/src/gadgets/ecdsa.rs | 117 ++++ ecdsa/src/gadgets/glv.rs | 180 ++++++ ecdsa/src/gadgets/nonnative.rs | 822 ++++++++++++++++++++++++ ecdsa/src/gadgets/split_nonnative.rs | 131 ++++ ecdsa/src/lib.rs | 4 + 20 files changed, 4138 insertions(+) create mode 100644 ecdsa/Cargo.toml create mode 100644 ecdsa/src/curve/curve_adds.rs create mode 100644 ecdsa/src/curve/curve_msm.rs create mode 100644 ecdsa/src/curve/curve_multiplication.rs create mode 100644 ecdsa/src/curve/curve_summation.rs create mode 100644 ecdsa/src/curve/curve_types.rs create mode 100644 ecdsa/src/curve/ecdsa.rs create mode 100644 ecdsa/src/curve/glv.rs create mode 100644 ecdsa/src/curve/mod.rs create mode 100644 ecdsa/src/curve/secp256k1.rs create mode 100644 ecdsa/src/gadgets/biguint.rs create mode 100644 ecdsa/src/gadgets/curve.rs create mode 100644 ecdsa/src/gadgets/curve_fixed_base.rs create mode 100644 ecdsa/src/gadgets/curve_msm.rs create mode 100644 ecdsa/src/gadgets/curve_windowed_mul.rs create mode 100644 ecdsa/src/gadgets/ecdsa.rs create mode 100644 ecdsa/src/gadgets/glv.rs create mode 100644 ecdsa/src/gadgets/nonnative.rs create mode 100644 ecdsa/src/gadgets/split_nonnative.rs create mode 100644 ecdsa/src/lib.rs diff --git a/ecdsa/Cargo.toml b/ecdsa/Cargo.toml new file mode 100644 index 00000000..59ad9ee3 --- /dev/null +++ b/ecdsa/Cargo.toml @@ -0,0 +1,19 @@ +[package] +name = "plonky2_ecdsa" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +plonky2 = { path = "../plonky2" } +plonky2_util = { path = "../util" } +plonky2_field = { path = "../field" } +num = "0.4.0" +itertools = "0.10.0" +rayon = "1.5.1" +serde = { version = "1.0", features = ["derive"] } +anyhow = "1.0.40" +rand = "0.8.4" +#env_logger = "0.9.0" +#log = "0.4.14" diff --git a/ecdsa/src/curve/curve_adds.rs b/ecdsa/src/curve/curve_adds.rs new file mode 100644 index 00000000..98dbc697 --- /dev/null +++ b/ecdsa/src/curve/curve_adds.rs @@ -0,0 +1,158 @@ +use std::ops::Add; + +use plonky2_field::field_types::Field; +use plonky2_field::ops::Square; + +use crate::curve::curve_types::{AffinePoint, Curve, ProjectivePoint}; + +impl Add> for ProjectivePoint { + type Output = ProjectivePoint; + + fn add(self, rhs: ProjectivePoint) -> Self::Output { + let ProjectivePoint { + x: x1, + y: y1, + z: z1, + } = self; + let ProjectivePoint { + x: x2, + y: y2, + z: z2, + } = rhs; + + if z1 == C::BaseField::ZERO { + return rhs; + } + if z2 == C::BaseField::ZERO { + return self; + } + + let x1z2 = x1 * z2; + let y1z2 = y1 * z2; + let x2z1 = x2 * z1; + let y2z1 = y2 * z1; + + // Check if we're doubling or adding inverses. + if x1z2 == x2z1 { + if y1z2 == y2z1 { + // TODO: inline to avoid redundant muls. + return self.double(); + } + if y1z2 == -y2z1 { + return ProjectivePoint::ZERO; + } + } + + // From https://www.hyperelliptic.org/EFD/g1p/data/shortw/projective/addition/add-1998-cmo-2 + let z1z2 = z1 * z2; + let u = y2z1 - y1z2; + let uu = u.square(); + let v = x2z1 - x1z2; + let vv = v.square(); + let vvv = v * vv; + let r = vv * x1z2; + let a = uu * z1z2 - vvv - r.double(); + let x3 = v * a; + let y3 = u * (r - a) - vvv * y1z2; + let z3 = vvv * z1z2; + ProjectivePoint::nonzero(x3, y3, z3) + } +} + +impl Add> for ProjectivePoint { + type Output = ProjectivePoint; + + fn add(self, rhs: AffinePoint) -> Self::Output { + let ProjectivePoint { + x: x1, + y: y1, + z: z1, + } = self; + let AffinePoint { + x: x2, + y: y2, + zero: zero2, + } = rhs; + + if z1 == C::BaseField::ZERO { + return rhs.to_projective(); + } + if zero2 { + return self; + } + + let x2z1 = x2 * z1; + let y2z1 = y2 * z1; + + // Check if we're doubling or adding inverses. + if x1 == x2z1 { + if y1 == y2z1 { + // TODO: inline to avoid redundant muls. + return self.double(); + } + if y1 == -y2z1 { + return ProjectivePoint::ZERO; + } + } + + // From https://www.hyperelliptic.org/EFD/g1p/data/shortw/projective/addition/madd-1998-cmo + let u = y2z1 - y1; + let uu = u.square(); + let v = x2z1 - x1; + let vv = v.square(); + let vvv = v * vv; + let r = vv * x1; + let a = uu * z1 - vvv - r.double(); + let x3 = v * a; + let y3 = u * (r - a) - vvv * y1; + let z3 = vvv * z1; + ProjectivePoint::nonzero(x3, y3, z3) + } +} + +impl Add> for AffinePoint { + type Output = ProjectivePoint; + + fn add(self, rhs: AffinePoint) -> Self::Output { + let AffinePoint { + x: x1, + y: y1, + zero: zero1, + } = self; + let AffinePoint { + x: x2, + y: y2, + zero: zero2, + } = rhs; + + if zero1 { + return rhs.to_projective(); + } + if zero2 { + return self.to_projective(); + } + + // Check if we're doubling or adding inverses. + if x1 == x2 { + if y1 == y2 { + return self.to_projective().double(); + } + if y1 == -y2 { + return ProjectivePoint::ZERO; + } + } + + // From https://www.hyperelliptic.org/EFD/g1p/data/shortw/projective/addition/mmadd-1998-cmo + let u = y2 - y1; + let uu = u.square(); + let v = x2 - x1; + let vv = v.square(); + let vvv = v * vv; + let r = vv * x1; + let a = uu - vvv - r.double(); + let x3 = v * a; + let y3 = u * (r - a) - vvv * y1; + let z3 = vvv; + ProjectivePoint::nonzero(x3, y3, z3) + } +} diff --git a/ecdsa/src/curve/curve_msm.rs b/ecdsa/src/curve/curve_msm.rs new file mode 100644 index 00000000..4c274c1c --- /dev/null +++ b/ecdsa/src/curve/curve_msm.rs @@ -0,0 +1,265 @@ +use itertools::Itertools; +use plonky2_field::field_types::Field; +use plonky2_field::field_types::PrimeField; +use rayon::prelude::*; + +use crate::curve::curve_summation::affine_multisummation_best; +use crate::curve::curve_types::{AffinePoint, Curve, ProjectivePoint}; + +/// In Yao's method, we compute an affine summation for each digit. In a parallel setting, it would +/// be easiest to assign individual summations to threads, but this would be sub-optimal because +/// multi-summations can be more efficient than repeating individual summations (see +/// `affine_multisummation_best`). Thus we divide digits into large chunks, and assign chunks of +/// digits to threads. Note that there is a delicate balance here, as large chunks can result in +/// uneven distributions of work among threads. +const DIGITS_PER_CHUNK: usize = 80; + +#[derive(Clone, Debug)] +pub struct MsmPrecomputation { + /// For each generator (in the order they were passed to `msm_precompute`), contains a vector + /// of powers, i.e. [(2^w)^i] for i < DIGITS. + // TODO: Use compressed coordinates here. + powers_per_generator: Vec>>, + + /// The window size. + w: usize, +} + +pub fn msm_precompute( + generators: &[ProjectivePoint], + w: usize, +) -> MsmPrecomputation { + MsmPrecomputation { + powers_per_generator: generators + .into_par_iter() + .map(|&g| precompute_single_generator(g, w)) + .collect(), + w, + } +} + +fn precompute_single_generator(g: ProjectivePoint, w: usize) -> Vec> { + let digits = (C::ScalarField::BITS + w - 1) / w; + let mut powers: Vec> = Vec::with_capacity(digits); + powers.push(g); + for i in 1..digits { + let mut power_i_proj = powers[i - 1]; + for _j in 0..w { + power_i_proj = power_i_proj.double(); + } + powers.push(power_i_proj); + } + ProjectivePoint::batch_to_affine(&powers) +} + +pub fn msm_parallel( + scalars: &[C::ScalarField], + generators: &[ProjectivePoint], + w: usize, +) -> ProjectivePoint { + let precomputation = msm_precompute(generators, w); + msm_execute_parallel(&precomputation, scalars) +} + +pub fn msm_execute( + precomputation: &MsmPrecomputation, + scalars: &[C::ScalarField], +) -> ProjectivePoint { + assert_eq!(precomputation.powers_per_generator.len(), scalars.len()); + let w = precomputation.w; + let digits = (C::ScalarField::BITS + w - 1) / w; + let base = 1 << w; + + // This is a variant of Yao's method, adapted to the multi-scalar setting. Because we use + // extremely large windows, the repeated scans in Yao's method could be more expensive than the + // actual group operations. To avoid this, we store a multimap from each possible digit to the + // positions in which that digit occurs in the scalars. These positions have the form (i, j), + // where i is the index of the generator and j is an index into the digits of the scalar + // associated with that generator. + let mut digit_occurrences: Vec> = Vec::with_capacity(digits); + for _i in 0..base { + digit_occurrences.push(Vec::new()); + } + for (i, scalar) in scalars.iter().enumerate() { + let digits = to_digits::(scalar, w); + for (j, &digit) in digits.iter().enumerate() { + digit_occurrences[digit].push((i, j)); + } + } + + let mut y = ProjectivePoint::ZERO; + let mut u = ProjectivePoint::ZERO; + + for digit in (1..base).rev() { + for &(i, j) in &digit_occurrences[digit] { + u = u + precomputation.powers_per_generator[i][j]; + } + y = y + u; + } + + y +} + +pub fn msm_execute_parallel( + precomputation: &MsmPrecomputation, + scalars: &[C::ScalarField], +) -> ProjectivePoint { + assert_eq!(precomputation.powers_per_generator.len(), scalars.len()); + let w = precomputation.w; + let digits = (C::ScalarField::BITS + w - 1) / w; + let base = 1 << w; + + // This is a variant of Yao's method, adapted to the multi-scalar setting. Because we use + // extremely large windows, the repeated scans in Yao's method could be more expensive than the + // actual group operations. To avoid this, we store a multimap from each possible digit to the + // positions in which that digit occurs in the scalars. These positions have the form (i, j), + // where i is the index of the generator and j is an index into the digits of the scalar + // associated with that generator. + let mut digit_occurrences: Vec> = Vec::with_capacity(digits); + for _i in 0..base { + digit_occurrences.push(Vec::new()); + } + for (i, scalar) in scalars.iter().enumerate() { + let digits = to_digits::(scalar, w); + for (j, &digit) in digits.iter().enumerate() { + digit_occurrences[digit].push((i, j)); + } + } + + // For each digit, we add up the powers associated with all occurrences that digit. + let digits: Vec = (0..base).collect(); + let digit_acc: Vec> = digits + .par_chunks(DIGITS_PER_CHUNK) + .flat_map(|chunk| { + let summations: Vec>> = chunk + .iter() + .map(|&digit| { + digit_occurrences[digit] + .iter() + .map(|&(i, j)| precomputation.powers_per_generator[i][j]) + .collect() + }) + .collect(); + affine_multisummation_best(summations) + }) + .collect(); + // println!("Computing the per-digit summations (in parallel) took {}s", start.elapsed().as_secs_f64()); + + let mut y = ProjectivePoint::ZERO; + let mut u = ProjectivePoint::ZERO; + for digit in (1..base).rev() { + u = u + digit_acc[digit]; + y = y + u; + } + // println!("Final summation (sequential) {}s", start.elapsed().as_secs_f64()); + y +} + +pub(crate) fn to_digits(x: &C::ScalarField, w: usize) -> Vec { + let scalar_bits = C::ScalarField::BITS; + let num_digits = (scalar_bits + w - 1) / w; + + // Convert x to a bool array. + let x_canonical: Vec<_> = x + .to_canonical_biguint() + .to_u64_digits() + .iter() + .cloned() + .pad_using(scalar_bits / 64, |_| 0) + .collect(); + let mut x_bits = Vec::with_capacity(scalar_bits); + for i in 0..scalar_bits { + x_bits.push((x_canonical[i / 64] >> (i as u64 % 64) & 1) != 0); + } + + let mut digits = Vec::with_capacity(num_digits); + for i in 0..num_digits { + let mut digit = 0; + for j in ((i * w)..((i + 1) * w).min(scalar_bits)).rev() { + digit <<= 1; + digit |= x_bits[j] as usize; + } + digits.push(digit); + } + digits +} + +#[cfg(test)] +mod tests { + use num::BigUint; + use plonky2_field::field_types::Field; + use plonky2_field::field_types::PrimeField; + use plonky2_field::secp256k1_scalar::Secp256K1Scalar; + + use crate::curve::curve_msm::{msm_execute, msm_precompute, to_digits}; + use crate::curve::curve_types::Curve; + use crate::curve::secp256k1::Secp256K1; + + #[test] + fn test_to_digits() { + let x_canonical = [ + 0b10101010101010101010101010101010, + 0b10101010101010101010101010101010, + 0b11001100110011001100110011001100, + 0b11001100110011001100110011001100, + 0b11110000111100001111000011110000, + 0b11110000111100001111000011110000, + 0b00001111111111111111111111111111, + 0b11111111111111111111111111111111, + ]; + let x = Secp256K1Scalar::from_biguint(BigUint::from_slice(&x_canonical)); + assert_eq!(x.to_canonical_biguint().to_u32_digits(), x_canonical); + assert_eq!( + to_digits::(&x, 17), + vec![ + 0b01010101010101010, + 0b10101010101010101, + 0b01010101010101010, + 0b11001010101010101, + 0b01100110011001100, + 0b00110011001100110, + 0b10011001100110011, + 0b11110000110011001, + 0b01111000011110000, + 0b00111100001111000, + 0b00011110000111100, + 0b11111111111111110, + 0b01111111111111111, + 0b11111111111111000, + 0b11111111111111111, + 0b1, + ] + ); + } + + #[test] + fn test_msm() { + let w = 5; + + let generator_1 = Secp256K1::GENERATOR_PROJECTIVE; + let generator_2 = generator_1 + generator_1; + let generator_3 = generator_1 + generator_2; + + let scalar_1 = Secp256K1Scalar::from_biguint(BigUint::from_slice(&[ + 11111111, 22222222, 33333333, 44444444, + ])); + let scalar_2 = Secp256K1Scalar::from_biguint(BigUint::from_slice(&[ + 22222222, 22222222, 33333333, 44444444, + ])); + let scalar_3 = Secp256K1Scalar::from_biguint(BigUint::from_slice(&[ + 33333333, 22222222, 33333333, 44444444, + ])); + + let generators = vec![generator_1, generator_2, generator_3]; + let scalars = vec![scalar_1, scalar_2, scalar_3]; + + let precomputation = msm_precompute(&generators, w); + let result_msm = msm_execute(&precomputation, &scalars); + + let result_naive = Secp256K1::convert(scalar_1) * generator_1 + + Secp256K1::convert(scalar_2) * generator_2 + + Secp256K1::convert(scalar_3) * generator_3; + + assert_eq!(result_msm, result_naive); + } +} diff --git a/ecdsa/src/curve/curve_multiplication.rs b/ecdsa/src/curve/curve_multiplication.rs new file mode 100644 index 00000000..9f2accaf --- /dev/null +++ b/ecdsa/src/curve/curve_multiplication.rs @@ -0,0 +1,100 @@ +use std::ops::Mul; + +use plonky2_field::field_types::Field; +use plonky2_field::field_types::PrimeField; + +use crate::curve::curve_types::{Curve, CurveScalar, ProjectivePoint}; + +const WINDOW_BITS: usize = 4; +const BASE: usize = 1 << WINDOW_BITS; + +fn digits_per_scalar() -> usize { + (C::ScalarField::BITS + WINDOW_BITS - 1) / WINDOW_BITS +} + +/// Precomputed state used for scalar x ProjectivePoint multiplications, +/// specific to a particular generator. +#[derive(Clone)] +pub struct MultiplicationPrecomputation { + /// [(2^w)^i] g for each i < digits_per_scalar. + powers: Vec>, +} + +impl ProjectivePoint { + pub fn mul_precompute(&self) -> MultiplicationPrecomputation { + let num_digits = digits_per_scalar::(); + let mut powers = Vec::with_capacity(num_digits); + powers.push(*self); + for i in 1..num_digits { + let mut power_i = powers[i - 1]; + for _j in 0..WINDOW_BITS { + power_i = power_i.double(); + } + powers.push(power_i); + } + + MultiplicationPrecomputation { powers } + } + + #[must_use] + pub fn mul_with_precomputation( + &self, + scalar: C::ScalarField, + precomputation: MultiplicationPrecomputation, + ) -> Self { + // Yao's method; see https://koclab.cs.ucsb.edu/teaching/ecc/eccPapers/Doche-ch09.pdf + let precomputed_powers = precomputation.powers; + + let digits = to_digits::(&scalar); + + let mut y = ProjectivePoint::ZERO; + let mut u = ProjectivePoint::ZERO; + let mut all_summands = Vec::new(); + for j in (1..BASE).rev() { + let mut u_summands = Vec::new(); + for (i, &digit) in digits.iter().enumerate() { + if digit == j as u64 { + u_summands.push(precomputed_powers[i]); + } + } + all_summands.push(u_summands); + } + + let all_sums: Vec> = all_summands + .iter() + .cloned() + .map(|vec| vec.iter().fold(ProjectivePoint::ZERO, |a, &b| a + b)) + .collect(); + for i in 0..all_sums.len() { + u = u + all_sums[i]; + y = y + u; + } + y + } +} + +impl Mul> for CurveScalar { + type Output = ProjectivePoint; + + fn mul(self, rhs: ProjectivePoint) -> Self::Output { + let precomputation = rhs.mul_precompute(); + rhs.mul_with_precomputation(self.0, precomputation) + } +} + +#[allow(clippy::assertions_on_constants)] +fn to_digits(x: &C::ScalarField) -> Vec { + debug_assert!( + 64 % WINDOW_BITS == 0, + "For simplicity, only power-of-two window sizes are handled for now" + ); + let digits_per_u64 = 64 / WINDOW_BITS; + let mut digits = Vec::with_capacity(digits_per_scalar::()); + for limb in x.to_canonical_biguint().to_u64_digits() { + for j in 0..digits_per_u64 { + digits.push((limb >> (j * WINDOW_BITS) as u64) % BASE as u64); + } + } + + digits +} diff --git a/ecdsa/src/curve/curve_summation.rs b/ecdsa/src/curve/curve_summation.rs new file mode 100644 index 00000000..7ea01524 --- /dev/null +++ b/ecdsa/src/curve/curve_summation.rs @@ -0,0 +1,239 @@ +use std::iter::Sum; + +use plonky2_field::field_types::Field; +use plonky2_field::ops::Square; + +use crate::curve::curve_types::{AffinePoint, Curve, ProjectivePoint}; + +impl Sum> for ProjectivePoint { + fn sum>>(iter: I) -> ProjectivePoint { + let points: Vec<_> = iter.collect(); + affine_summation_best(points) + } +} + +impl Sum for ProjectivePoint { + fn sum>>(iter: I) -> ProjectivePoint { + iter.fold(ProjectivePoint::ZERO, |acc, x| acc + x) + } +} + +pub fn affine_summation_best(summation: Vec>) -> ProjectivePoint { + let result = affine_multisummation_best(vec![summation]); + debug_assert_eq!(result.len(), 1); + result[0] +} + +pub fn affine_multisummation_best( + summations: Vec>>, +) -> Vec> { + let pairwise_sums: usize = summations.iter().map(|summation| summation.len() / 2).sum(); + + // This threshold is chosen based on data from the summation benchmarks. + if pairwise_sums < 70 { + affine_multisummation_pairwise(summations) + } else { + affine_multisummation_batch_inversion(summations) + } +} + +/// Adds each pair of points using an affine + affine = projective formula, then adds up the +/// intermediate sums using a projective formula. +pub fn affine_multisummation_pairwise( + summations: Vec>>, +) -> Vec> { + summations + .into_iter() + .map(affine_summation_pairwise) + .collect() +} + +/// Adds each pair of points using an affine + affine = projective formula, then adds up the +/// intermediate sums using a projective formula. +pub fn affine_summation_pairwise(points: Vec>) -> ProjectivePoint { + let mut reduced_points: Vec> = Vec::new(); + for chunk in points.chunks(2) { + match chunk.len() { + 1 => reduced_points.push(chunk[0].to_projective()), + 2 => reduced_points.push(chunk[0] + chunk[1]), + _ => panic!(), + } + } + // TODO: Avoid copying (deref) + reduced_points + .iter() + .fold(ProjectivePoint::ZERO, |sum, x| sum + *x) +} + +/// Computes several summations of affine points by applying an affine group law, except that the +/// divisions are batched via Montgomery's trick. +pub fn affine_summation_batch_inversion( + summation: Vec>, +) -> ProjectivePoint { + let result = affine_multisummation_batch_inversion(vec![summation]); + debug_assert_eq!(result.len(), 1); + result[0] +} + +/// Computes several summations of affine points by applying an affine group law, except that the +/// divisions are batched via Montgomery's trick. +pub fn affine_multisummation_batch_inversion( + summations: Vec>>, +) -> Vec> { + let mut elements_to_invert = Vec::new(); + + // For each pair of points, (x1, y1) and (x2, y2), that we're going to add later, we want to + // invert either y (if the points are equal) or x1 - x2 (otherwise). We will use these later. + for summation in &summations { + let n = summation.len(); + // The special case for n=0 is to avoid underflow. + let range_end = if n == 0 { 0 } else { n - 1 }; + + for i in (0..range_end).step_by(2) { + let p1 = summation[i]; + let p2 = summation[i + 1]; + let AffinePoint { + x: x1, + y: y1, + zero: zero1, + } = p1; + let AffinePoint { + x: x2, + y: _y2, + zero: zero2, + } = p2; + + if zero1 || zero2 || p1 == -p2 { + // These are trivial cases where we won't need any inverse. + } else if p1 == p2 { + elements_to_invert.push(y1.double()); + } else { + elements_to_invert.push(x1 - x2); + } + } + } + + let inverses: Vec = + C::BaseField::batch_multiplicative_inverse(&elements_to_invert); + + let mut all_reduced_points = Vec::with_capacity(summations.len()); + let mut inverse_index = 0; + for summation in summations { + let n = summation.len(); + let mut reduced_points = Vec::with_capacity((n + 1) / 2); + + // The special case for n=0 is to avoid underflow. + let range_end = if n == 0 { 0 } else { n - 1 }; + + for i in (0..range_end).step_by(2) { + let p1 = summation[i]; + let p2 = summation[i + 1]; + let AffinePoint { + x: x1, + y: y1, + zero: zero1, + } = p1; + let AffinePoint { + x: x2, + y: y2, + zero: zero2, + } = p2; + + let sum = if zero1 { + p2 + } else if zero2 { + p1 + } else if p1 == -p2 { + AffinePoint::ZERO + } else { + // It's a non-trivial case where we need one of the inverses we computed earlier. + let inverse = inverses[inverse_index]; + inverse_index += 1; + + if p1 == p2 { + // This is the doubling case. + let mut numerator = x1.square().triple(); + if C::A.is_nonzero() { + numerator += C::A; + } + let quotient = numerator * inverse; + let x3 = quotient.square() - x1.double(); + let y3 = quotient * (x1 - x3) - y1; + AffinePoint::nonzero(x3, y3) + } else { + // This is the general case. We use the incomplete addition formulas 4.3 and 4.4. + let quotient = (y1 - y2) * inverse; + let x3 = quotient.square() - x1 - x2; + let y3 = quotient * (x1 - x3) - y1; + AffinePoint::nonzero(x3, y3) + } + }; + reduced_points.push(sum); + } + + // If n is odd, the last point was not part of a pair. + if n % 2 == 1 { + reduced_points.push(summation[n - 1]); + } + + all_reduced_points.push(reduced_points); + } + + // We should have consumed all of the inverses from the batch computation. + debug_assert_eq!(inverse_index, inverses.len()); + + // Recurse with our smaller set of points. + affine_multisummation_best(all_reduced_points) +} + +#[cfg(test)] +mod tests { + use crate::curve::curve_summation::{ + affine_summation_batch_inversion, affine_summation_pairwise, + }; + use crate::curve::curve_types::{Curve, ProjectivePoint}; + use crate::curve::secp256k1::Secp256K1; + + #[test] + fn test_pairwise_affine_summation() { + let g_affine = Secp256K1::GENERATOR_AFFINE; + let g2_affine = (g_affine + g_affine).to_affine(); + let g3_affine = (g_affine + g_affine + g_affine).to_affine(); + let g2_proj = g2_affine.to_projective(); + let g3_proj = g3_affine.to_projective(); + assert_eq!( + affine_summation_pairwise::(vec![g_affine, g_affine]), + g2_proj + ); + assert_eq!( + affine_summation_pairwise::(vec![g_affine, g2_affine]), + g3_proj + ); + assert_eq!( + affine_summation_pairwise::(vec![g_affine, g_affine, g_affine]), + g3_proj + ); + assert_eq!( + affine_summation_pairwise::(vec![]), + ProjectivePoint::ZERO + ); + } + + #[test] + fn test_pairwise_affine_summation_batch_inversion() { + let g = Secp256K1::GENERATOR_AFFINE; + let g_proj = g.to_projective(); + assert_eq!( + affine_summation_batch_inversion::(vec![g, g]), + g_proj + g_proj + ); + assert_eq!( + affine_summation_batch_inversion::(vec![g, g, g]), + g_proj + g_proj + g_proj + ); + assert_eq!( + affine_summation_batch_inversion::(vec![]), + ProjectivePoint::ZERO + ); + } +} diff --git a/ecdsa/src/curve/curve_types.rs b/ecdsa/src/curve/curve_types.rs new file mode 100644 index 00000000..be025e12 --- /dev/null +++ b/ecdsa/src/curve/curve_types.rs @@ -0,0 +1,285 @@ +use std::fmt::Debug; +use std::hash::Hash; +use std::ops::Neg; + +use plonky2_field::field_types::{Field, PrimeField}; +use plonky2_field::ops::Square; +use serde::{Deserialize, Serialize}; + +// To avoid implementation conflicts from associated types, +// see https://github.com/rust-lang/rust/issues/20400 +pub struct CurveScalar(pub ::ScalarField); + +/// A short Weierstrass curve. +pub trait Curve: 'static + Sync + Sized + Copy + Debug { + type BaseField: PrimeField; + type ScalarField: PrimeField; + + const A: Self::BaseField; + const B: Self::BaseField; + + const GENERATOR_AFFINE: AffinePoint; + + const GENERATOR_PROJECTIVE: ProjectivePoint = ProjectivePoint { + x: Self::GENERATOR_AFFINE.x, + y: Self::GENERATOR_AFFINE.y, + z: Self::BaseField::ONE, + }; + + fn convert(x: Self::ScalarField) -> CurveScalar { + CurveScalar(x) + } + + fn is_safe_curve() -> bool { + // Added additional check to prevent using vulnerabilties in case a discriminant is equal to 0. + (Self::A.cube().double().double() + Self::B.square().triple().triple().triple()) + .is_nonzero() + } +} + +/// A point on a short Weierstrass curve, represented in affine coordinates. +#[derive(Copy, Clone, Debug, Deserialize, Serialize)] +pub struct AffinePoint { + pub x: C::BaseField, + pub y: C::BaseField, + pub zero: bool, +} + +impl AffinePoint { + pub const ZERO: Self = Self { + x: C::BaseField::ZERO, + y: C::BaseField::ZERO, + zero: true, + }; + + pub fn nonzero(x: C::BaseField, y: C::BaseField) -> Self { + let point = Self { x, y, zero: false }; + debug_assert!(point.is_valid()); + point + } + + pub fn is_valid(&self) -> bool { + let Self { x, y, zero } = *self; + zero || y.square() == x.cube() + C::A * x + C::B + } + + pub fn to_projective(&self) -> ProjectivePoint { + let Self { x, y, zero } = *self; + let z = if zero { + C::BaseField::ZERO + } else { + C::BaseField::ONE + }; + + ProjectivePoint { x, y, z } + } + + pub fn batch_to_projective(affine_points: &[Self]) -> Vec> { + affine_points.iter().map(Self::to_projective).collect() + } + + #[must_use] + pub fn double(&self) -> Self { + let AffinePoint { x: x1, y: y1, zero } = *self; + + if zero { + return AffinePoint::ZERO; + } + + let double_y = y1.double(); + let inv_double_y = double_y.inverse(); // (2y)^(-1) + let triple_xx = x1.square().triple(); // 3x^2 + let lambda = (triple_xx + C::A) * inv_double_y; + let x3 = lambda.square() - self.x.double(); + let y3 = lambda * (x1 - x3) - y1; + + Self { + x: x3, + y: y3, + zero: false, + } + } +} + +impl PartialEq for AffinePoint { + fn eq(&self, other: &Self) -> bool { + let AffinePoint { + x: x1, + y: y1, + zero: zero1, + } = *self; + let AffinePoint { + x: x2, + y: y2, + zero: zero2, + } = *other; + if zero1 || zero2 { + return zero1 == zero2; + } + x1 == x2 && y1 == y2 + } +} + +impl Eq for AffinePoint {} + +impl Hash for AffinePoint { + fn hash(&self, state: &mut H) { + if self.zero { + self.zero.hash(state); + } else { + self.x.hash(state); + self.y.hash(state); + } + } +} + +/// A point on a short Weierstrass curve, represented in projective coordinates. +#[derive(Copy, Clone, Debug)] +pub struct ProjectivePoint { + pub x: C::BaseField, + pub y: C::BaseField, + pub z: C::BaseField, +} + +impl ProjectivePoint { + pub const ZERO: Self = Self { + x: C::BaseField::ZERO, + y: C::BaseField::ONE, + z: C::BaseField::ZERO, + }; + + pub fn nonzero(x: C::BaseField, y: C::BaseField, z: C::BaseField) -> Self { + let point = Self { x, y, z }; + debug_assert!(point.is_valid()); + point + } + + pub fn is_valid(&self) -> bool { + let Self { x, y, z } = *self; + z.is_zero() || y.square() * z == x.cube() + C::A * x * z.square() + C::B * z.cube() + } + + pub fn to_affine(&self) -> AffinePoint { + let Self { x, y, z } = *self; + if z == C::BaseField::ZERO { + AffinePoint::ZERO + } else { + let z_inv = z.inverse(); + AffinePoint::nonzero(x * z_inv, y * z_inv) + } + } + + pub fn batch_to_affine(proj_points: &[Self]) -> Vec> { + let n = proj_points.len(); + let zs: Vec = proj_points.iter().map(|pp| pp.z).collect(); + let z_invs = C::BaseField::batch_multiplicative_inverse(&zs); + + let mut result = Vec::with_capacity(n); + for i in 0..n { + let Self { x, y, z } = proj_points[i]; + result.push(if z == C::BaseField::ZERO { + AffinePoint::ZERO + } else { + let z_inv = z_invs[i]; + AffinePoint::nonzero(x * z_inv, y * z_inv) + }); + } + result + } + + // From https://www.hyperelliptic.org/EFD/g1p/data/shortw/projective/doubling/dbl-2007-bl + #[must_use] + pub fn double(&self) -> Self { + let Self { x, y, z } = *self; + if z == C::BaseField::ZERO { + return ProjectivePoint::ZERO; + } + + let xx = x.square(); + let zz = z.square(); + let mut w = xx.triple(); + if C::A.is_nonzero() { + w += C::A * zz; + } + let s = y.double() * z; + let r = y * s; + let rr = r.square(); + let b = (x + r).square() - (xx + rr); + let h = w.square() - b.double(); + let x3 = h * s; + let y3 = w * (b - h) - rr.double(); + let z3 = s.cube(); + Self { + x: x3, + y: y3, + z: z3, + } + } + + pub fn add_slices(a: &[Self], b: &[Self]) -> Vec { + assert_eq!(a.len(), b.len()); + a.iter() + .zip(b.iter()) + .map(|(&a_i, &b_i)| a_i + b_i) + .collect() + } + + #[must_use] + pub fn neg(&self) -> Self { + Self { + x: self.x, + y: -self.y, + z: self.z, + } + } +} + +impl PartialEq for ProjectivePoint { + fn eq(&self, other: &Self) -> bool { + let ProjectivePoint { + x: x1, + y: y1, + z: z1, + } = *self; + let ProjectivePoint { + x: x2, + y: y2, + z: z2, + } = *other; + if z1 == C::BaseField::ZERO || z2 == C::BaseField::ZERO { + return z1 == z2; + } + + // We want to compare (x1/z1, y1/z1) == (x2/z2, y2/z2). + // But to avoid field division, it is better to compare (x1*z2, y1*z2) == (x2*z1, y2*z1). + x1 * z2 == x2 * z1 && y1 * z2 == y2 * z1 + } +} + +impl Eq for ProjectivePoint {} + +impl Neg for AffinePoint { + type Output = AffinePoint; + + fn neg(self) -> Self::Output { + let AffinePoint { x, y, zero } = self; + AffinePoint { x, y: -y, zero } + } +} + +impl Neg for ProjectivePoint { + type Output = ProjectivePoint; + + fn neg(self) -> Self::Output { + let ProjectivePoint { x, y, z } = self; + ProjectivePoint { x, y: -y, z } + } +} + +pub fn base_to_scalar(x: C::BaseField) -> C::ScalarField { + C::ScalarField::from_biguint(x.to_canonical_biguint()) +} + +pub fn scalar_to_base(x: C::ScalarField) -> C::BaseField { + C::BaseField::from_biguint(x.to_canonical_biguint()) +} diff --git a/ecdsa/src/curve/ecdsa.rs b/ecdsa/src/curve/ecdsa.rs new file mode 100644 index 00000000..91b2cea9 --- /dev/null +++ b/ecdsa/src/curve/ecdsa.rs @@ -0,0 +1,79 @@ +use plonky2_field::field_types::Field; +use serde::{Deserialize, Serialize}; + +use crate::curve::curve_msm::msm_parallel; +use crate::curve::curve_types::{base_to_scalar, AffinePoint, Curve, CurveScalar}; + +#[derive(Copy, Clone, Debug, Deserialize, Eq, Hash, PartialEq, Serialize)] +pub struct ECDSASignature { + pub r: C::ScalarField, + pub s: C::ScalarField, +} + +#[derive(Copy, Clone, Debug, Deserialize, Eq, Hash, PartialEq, Serialize)] +pub struct ECDSASecretKey(pub C::ScalarField); + +#[derive(Copy, Clone, Debug, Deserialize, Eq, Hash, PartialEq, Serialize)] +pub struct ECDSAPublicKey(pub AffinePoint); + +pub fn sign_message(msg: C::ScalarField, sk: ECDSASecretKey) -> ECDSASignature { + let (k, rr) = { + let mut k = C::ScalarField::rand(); + let mut rr = (CurveScalar(k) * C::GENERATOR_PROJECTIVE).to_affine(); + while rr.x == C::BaseField::ZERO { + k = C::ScalarField::rand(); + rr = (CurveScalar(k) * C::GENERATOR_PROJECTIVE).to_affine(); + } + (k, rr) + }; + let r = base_to_scalar::(rr.x); + + let s = k.inverse() * (msg + r * sk.0); + + ECDSASignature { r, s } +} + +pub fn verify_message( + msg: C::ScalarField, + sig: ECDSASignature, + pk: ECDSAPublicKey, +) -> bool { + let ECDSASignature { r, s } = sig; + + assert!(pk.0.is_valid()); + + let c = s.inverse(); + let u1 = msg * c; + let u2 = r * c; + + let g = C::GENERATOR_PROJECTIVE; + let w = 5; // Experimentally fastest + let point_proj = msm_parallel(&[u1, u2], &[g, pk.0.to_projective()], w); + let point = point_proj.to_affine(); + + let x = base_to_scalar::(point.x); + r == x +} + +#[cfg(test)] +mod tests { + use plonky2_field::field_types::Field; + use plonky2_field::secp256k1_scalar::Secp256K1Scalar; + + use crate::curve::curve_types::{Curve, CurveScalar}; + use crate::curve::ecdsa::{sign_message, verify_message, ECDSAPublicKey, ECDSASecretKey}; + use crate::curve::secp256k1::Secp256K1; + + #[test] + fn test_ecdsa_native() { + type C = Secp256K1; + + let msg = Secp256K1Scalar::rand(); + let sk = ECDSASecretKey(Secp256K1Scalar::rand()); + let pk = ECDSAPublicKey((CurveScalar(sk.0) * C::GENERATOR_PROJECTIVE).to_affine()); + + let sig = sign_message(msg, sk); + let result = verify_message(msg, sig, pk); + assert!(result); + } +} diff --git a/ecdsa/src/curve/glv.rs b/ecdsa/src/curve/glv.rs new file mode 100644 index 00000000..aeeb463e --- /dev/null +++ b/ecdsa/src/curve/glv.rs @@ -0,0 +1,136 @@ +use num::rational::Ratio; +use num::BigUint; +use plonky2_field::field_types::{Field, PrimeField}; +use plonky2_field::secp256k1_base::Secp256K1Base; +use plonky2_field::secp256k1_scalar::Secp256K1Scalar; + +use crate::curve::curve_msm::msm_parallel; +use crate::curve::curve_types::{AffinePoint, ProjectivePoint}; +use crate::curve::secp256k1::Secp256K1; + +pub const GLV_BETA: Secp256K1Base = Secp256K1Base([ + 13923278643952681454, + 11308619431505398165, + 7954561588662645993, + 8856726876819556112, +]); + +pub const GLV_S: Secp256K1Scalar = Secp256K1Scalar([ + 16069571880186789234, + 1310022930574435960, + 11900229862571533402, + 6008836872998760672, +]); + +const A1: Secp256K1Scalar = Secp256K1Scalar([16747920425669159701, 3496713202691238861, 0, 0]); + +const MINUS_B1: Secp256K1Scalar = + Secp256K1Scalar([8022177200260244675, 16448129721693014056, 0, 0]); + +const A2: Secp256K1Scalar = Secp256K1Scalar([6323353552219852760, 1498098850674701302, 1, 0]); + +const B2: Secp256K1Scalar = Secp256K1Scalar([16747920425669159701, 3496713202691238861, 0, 0]); + +pub fn decompose_secp256k1_scalar( + k: Secp256K1Scalar, +) -> (Secp256K1Scalar, Secp256K1Scalar, bool, bool) { + let p = Secp256K1Scalar::order(); + let c1_biguint = Ratio::new( + B2.to_canonical_biguint() * k.to_canonical_biguint(), + p.clone(), + ) + .round() + .to_integer(); + let c1 = Secp256K1Scalar::from_biguint(c1_biguint); + let c2_biguint = Ratio::new( + MINUS_B1.to_canonical_biguint() * k.to_canonical_biguint(), + p.clone(), + ) + .round() + .to_integer(); + let c2 = Secp256K1Scalar::from_biguint(c2_biguint); + + let k1_raw = k - c1 * A1 - c2 * A2; + let k2_raw = c1 * MINUS_B1 - c2 * B2; + debug_assert!(k1_raw + GLV_S * k2_raw == k); + + let two = BigUint::from_slice(&[2]); + let k1_neg = k1_raw.to_canonical_biguint() > p.clone() / two.clone(); + let k1 = if k1_neg { + Secp256K1Scalar::from_biguint(p.clone() - k1_raw.to_canonical_biguint()) + } else { + k1_raw + }; + let k2_neg = k2_raw.to_canonical_biguint() > p.clone() / two; + let k2 = if k2_neg { + Secp256K1Scalar::from_biguint(p - k2_raw.to_canonical_biguint()) + } else { + k2_raw + }; + + (k1, k2, k1_neg, k2_neg) +} + +pub fn glv_mul(p: ProjectivePoint, k: Secp256K1Scalar) -> ProjectivePoint { + let (k1, k2, k1_neg, k2_neg) = decompose_secp256k1_scalar(k); + /*let one = Secp256K1Scalar::ONE; + let m1 = if k1_neg { -one } else { one }; + let m2 = if k2_neg { -one } else { one }; + assert!(k1 * m1 + S * k2 * m2 == k);*/ + + let p_affine = p.to_affine(); + let sp = AffinePoint:: { + x: p_affine.x * GLV_BETA, + y: p_affine.y, + zero: p_affine.zero, + }; + + let first = if k1_neg { p.neg() } else { p }; + let second = if k2_neg { + sp.to_projective().neg() + } else { + sp.to_projective() + }; + + msm_parallel(&[k1, k2], &[first, second], 5) +} + +#[cfg(test)] +mod tests { + use anyhow::Result; + use plonky2_field::field_types::Field; + use plonky2_field::secp256k1_scalar::Secp256K1Scalar; + + use crate::curve::curve_types::{Curve, CurveScalar}; + use crate::curve::glv::{decompose_secp256k1_scalar, glv_mul, GLV_S}; + use crate::curve::secp256k1::Secp256K1; + + #[test] + fn test_glv_decompose() -> Result<()> { + let k = Secp256K1Scalar::rand(); + let (k1, k2, k1_neg, k2_neg) = decompose_secp256k1_scalar(k); + let one = Secp256K1Scalar::ONE; + let m1 = if k1_neg { -one } else { one }; + let m2 = if k2_neg { -one } else { one }; + + assert!(k1 * m1 + GLV_S * k2 * m2 == k); + + Ok(()) + } + + #[test] + fn test_glv_mul() -> Result<()> { + for _ in 0..20 { + let k = Secp256K1Scalar::rand(); + + let p = CurveScalar(Secp256K1Scalar::rand()) * Secp256K1::GENERATOR_PROJECTIVE; + + let kp = CurveScalar(k) * p; + let glv = glv_mul(p, k); + + assert!(kp == glv); + } + + Ok(()) + } +} diff --git a/ecdsa/src/curve/mod.rs b/ecdsa/src/curve/mod.rs new file mode 100644 index 00000000..1984b0c6 --- /dev/null +++ b/ecdsa/src/curve/mod.rs @@ -0,0 +1,8 @@ +pub mod curve_adds; +pub mod curve_msm; +pub mod curve_multiplication; +pub mod curve_summation; +pub mod curve_types; +pub mod ecdsa; +pub mod glv; +pub mod secp256k1; diff --git a/ecdsa/src/curve/secp256k1.rs b/ecdsa/src/curve/secp256k1.rs new file mode 100644 index 00000000..18040dae --- /dev/null +++ b/ecdsa/src/curve/secp256k1.rs @@ -0,0 +1,101 @@ +use plonky2_field::field_types::Field; +use plonky2_field::secp256k1_base::Secp256K1Base; +use plonky2_field::secp256k1_scalar::Secp256K1Scalar; +use serde::{Deserialize, Serialize}; + +use crate::curve::curve_types::{AffinePoint, Curve}; + +#[derive(Debug, Copy, Clone, Deserialize, Eq, Hash, PartialEq, Serialize)] +pub struct Secp256K1; + +impl Curve for Secp256K1 { + type BaseField = Secp256K1Base; + type ScalarField = Secp256K1Scalar; + + const A: Secp256K1Base = Secp256K1Base::ZERO; + const B: Secp256K1Base = Secp256K1Base([7, 0, 0, 0]); + const GENERATOR_AFFINE: AffinePoint = AffinePoint { + x: SECP256K1_GENERATOR_X, + y: SECP256K1_GENERATOR_Y, + zero: false, + }; +} + +// 55066263022277343669578718895168534326250603453777594175500187360389116729240 +const SECP256K1_GENERATOR_X: Secp256K1Base = Secp256K1Base([ + 0x59F2815B16F81798, + 0x029BFCDB2DCE28D9, + 0x55A06295CE870B07, + 0x79BE667EF9DCBBAC, +]); + +/// 32670510020758816978083085130507043184471273380659243275938904335757337482424 +const SECP256K1_GENERATOR_Y: Secp256K1Base = Secp256K1Base([ + 0x9C47D08FFB10D4B8, + 0xFD17B448A6855419, + 0x5DA4FBFC0E1108A8, + 0x483ADA7726A3C465, +]); + +#[cfg(test)] +mod tests { + use num::BigUint; + use plonky2_field::field_types::Field; + use plonky2_field::field_types::PrimeField; + use plonky2_field::secp256k1_scalar::Secp256K1Scalar; + + use crate::curve::curve_types::{AffinePoint, Curve, ProjectivePoint}; + use crate::curve::secp256k1::Secp256K1; + + #[test] + fn test_generator() { + let g = Secp256K1::GENERATOR_AFFINE; + assert!(g.is_valid()); + + let neg_g = AffinePoint:: { + x: g.x, + y: -g.y, + zero: g.zero, + }; + assert!(neg_g.is_valid()); + } + + #[test] + fn test_naive_multiplication() { + let g = Secp256K1::GENERATOR_PROJECTIVE; + let ten = Secp256K1Scalar::from_canonical_u64(10); + let product = mul_naive(ten, g); + let sum = g + g + g + g + g + g + g + g + g + g; + assert_eq!(product, sum); + } + + #[test] + fn test_g1_multiplication() { + let lhs = Secp256K1Scalar::from_biguint(BigUint::from_slice(&[ + 1111, 2222, 3333, 4444, 5555, 6666, 7777, 8888, + ])); + assert_eq!( + Secp256K1::convert(lhs) * Secp256K1::GENERATOR_PROJECTIVE, + mul_naive(lhs, Secp256K1::GENERATOR_PROJECTIVE) + ); + } + + /// A simple, somewhat inefficient implementation of multiplication which is used as a reference + /// for correctness. + fn mul_naive( + lhs: Secp256K1Scalar, + rhs: ProjectivePoint, + ) -> ProjectivePoint { + let mut g = rhs; + let mut sum = ProjectivePoint::ZERO; + for limb in lhs.to_canonical_biguint().to_u64_digits().iter() { + for j in 0..64 { + if (limb >> j & 1u64) != 0u64 { + sum = sum + g; + } + g = g.double(); + } + } + sum + } +} diff --git a/ecdsa/src/gadgets/biguint.rs b/ecdsa/src/gadgets/biguint.rs new file mode 100644 index 00000000..5c077747 --- /dev/null +++ b/ecdsa/src/gadgets/biguint.rs @@ -0,0 +1,505 @@ +use std::marker::PhantomData; + +use num::{BigUint, Integer, Zero}; +use plonky2::gadgets::arithmetic_u32::U32Target; +use plonky2::hash::hash_types::RichField; +use plonky2::iop::generator::{GeneratedValues, SimpleGenerator}; +use plonky2::iop::target::{BoolTarget, Target}; +use plonky2::iop::witness::{PartitionWitness, Witness}; +use plonky2::plonk::circuit_builder::CircuitBuilder; +use plonky2_field::extension_field::Extendable; +use plonky2_field::field_types::PrimeField; + +#[derive(Clone, Debug)] +pub struct BigUintTarget { + pub limbs: Vec, +} + +impl BigUintTarget { + pub fn num_limbs(&self) -> usize { + self.limbs.len() + } + + pub fn get_limb(&self, i: usize) -> U32Target { + self.limbs[i] + } +} + +pub trait CircuitBuilderBiguint, const D: usize> { + fn constant_biguint(&mut self, value: &BigUint) -> BigUintTarget; + + fn zero_biguint(&mut self) -> BigUintTarget; + + fn connect_biguint(&mut self, lhs: &BigUintTarget, rhs: &BigUintTarget); + + fn pad_biguints( + &mut self, + a: &BigUintTarget, + b: &BigUintTarget, + ) -> (BigUintTarget, BigUintTarget); + + fn cmp_biguint(&mut self, a: &BigUintTarget, b: &BigUintTarget) -> BoolTarget; + + fn add_virtual_biguint_target(&mut self, num_limbs: usize) -> BigUintTarget; + + // Add two `BigUintTarget`s. + fn add_biguint(&mut self, a: &BigUintTarget, b: &BigUintTarget) -> BigUintTarget; + + // Subtract two `BigUintTarget`s. We assume that the first is larger than the second. + fn sub_biguint(&mut self, a: &BigUintTarget, b: &BigUintTarget) -> BigUintTarget; + + fn mul_biguint(&mut self, a: &BigUintTarget, b: &BigUintTarget) -> BigUintTarget; + + fn mul_biguint_by_bool(&mut self, a: &BigUintTarget, b: BoolTarget) -> BigUintTarget; + + // Returns x * y + z. This is no more efficient than mul-then-add; it's purely for convenience (only need to call one CircuitBuilder function). + fn mul_add_biguint( + &mut self, + x: &BigUintTarget, + y: &BigUintTarget, + z: &BigUintTarget, + ) -> BigUintTarget; + + fn div_rem_biguint( + &mut self, + a: &BigUintTarget, + b: &BigUintTarget, + ) -> (BigUintTarget, BigUintTarget); + + fn div_biguint(&mut self, a: &BigUintTarget, b: &BigUintTarget) -> BigUintTarget; + + fn rem_biguint(&mut self, a: &BigUintTarget, b: &BigUintTarget) -> BigUintTarget; +} + +impl, const D: usize> CircuitBuilderBiguint + for CircuitBuilder +{ + fn constant_biguint(&mut self, value: &BigUint) -> BigUintTarget { + let limb_values = value.to_u32_digits(); + let limbs = limb_values.iter().map(|&l| self.constant_u32(l)).collect(); + + BigUintTarget { limbs } + } + + fn zero_biguint(&mut self) -> BigUintTarget { + self.constant_biguint(&BigUint::zero()) + } + + fn connect_biguint(&mut self, lhs: &BigUintTarget, rhs: &BigUintTarget) { + let min_limbs = lhs.num_limbs().min(rhs.num_limbs()); + for i in 0..min_limbs { + self.connect_u32(lhs.get_limb(i), rhs.get_limb(i)); + } + + for i in min_limbs..lhs.num_limbs() { + self.assert_zero_u32(lhs.get_limb(i)); + } + for i in min_limbs..rhs.num_limbs() { + self.assert_zero_u32(rhs.get_limb(i)); + } + } + + fn pad_biguints( + &mut self, + a: &BigUintTarget, + b: &BigUintTarget, + ) -> (BigUintTarget, BigUintTarget) { + if a.num_limbs() > b.num_limbs() { + let mut padded_b = b.clone(); + for _ in b.num_limbs()..a.num_limbs() { + padded_b.limbs.push(self.zero_u32()); + } + + (a.clone(), padded_b) + } else { + let mut padded_a = a.clone(); + for _ in a.num_limbs()..b.num_limbs() { + padded_a.limbs.push(self.zero_u32()); + } + + (padded_a, b.clone()) + } + } + + fn cmp_biguint(&mut self, a: &BigUintTarget, b: &BigUintTarget) -> BoolTarget { + let (a, b) = self.pad_biguints(a, b); + + self.list_le_u32(a.limbs, b.limbs) + } + + fn add_virtual_biguint_target(&mut self, num_limbs: usize) -> BigUintTarget { + let limbs = self.add_virtual_u32_targets(num_limbs); + + BigUintTarget { limbs } + } + + // Add two `BigUintTarget`s. + fn add_biguint(&mut self, a: &BigUintTarget, b: &BigUintTarget) -> BigUintTarget { + let num_limbs = a.num_limbs().max(b.num_limbs()); + + let mut combined_limbs = vec![]; + let mut carry = self.zero_u32(); + for i in 0..num_limbs { + let a_limb = (i < a.num_limbs()) + .then(|| a.limbs[i]) + .unwrap_or_else(|| self.zero_u32()); + let b_limb = (i < b.num_limbs()) + .then(|| b.limbs[i]) + .unwrap_or_else(|| self.zero_u32()); + + let (new_limb, new_carry) = self.add_many_u32(&[carry, a_limb, b_limb]); + carry = new_carry; + combined_limbs.push(new_limb); + } + combined_limbs.push(carry); + + BigUintTarget { + limbs: combined_limbs, + } + } + + // Subtract two `BigUintTarget`s. We assume that the first is larger than the second. + fn sub_biguint(&mut self, a: &BigUintTarget, b: &BigUintTarget) -> BigUintTarget { + let (a, b) = self.pad_biguints(a, b); + let num_limbs = a.limbs.len(); + + let mut result_limbs = vec![]; + + let mut borrow = self.zero_u32(); + for i in 0..num_limbs { + let (result, new_borrow) = self.sub_u32(a.limbs[i], b.limbs[i], borrow); + result_limbs.push(result); + borrow = new_borrow; + } + // Borrow should be zero here. + + BigUintTarget { + limbs: result_limbs, + } + } + + fn mul_biguint(&mut self, a: &BigUintTarget, b: &BigUintTarget) -> BigUintTarget { + let total_limbs = a.limbs.len() + b.limbs.len(); + + let mut to_add = vec![vec![]; total_limbs]; + for i in 0..a.limbs.len() { + for j in 0..b.limbs.len() { + let (product, carry) = self.mul_u32(a.limbs[i], b.limbs[j]); + to_add[i + j].push(product); + to_add[i + j + 1].push(carry); + } + } + + let mut combined_limbs = vec![]; + let mut carry = self.zero_u32(); + for summands in &mut to_add { + let (new_result, new_carry) = self.add_u32s_with_carry(summands, carry); + combined_limbs.push(new_result); + carry = new_carry; + } + combined_limbs.push(carry); + + BigUintTarget { + limbs: combined_limbs, + } + } + + fn mul_biguint_by_bool(&mut self, a: &BigUintTarget, b: BoolTarget) -> BigUintTarget { + let t = b.target; + + BigUintTarget { + limbs: a + .limbs + .iter() + .map(|&l| U32Target(self.mul(l.0, t))) + .collect(), + } + } + + // Returns x * y + z. This is no more efficient than mul-then-add; it's purely for convenience (only need to call one CircuitBuilder function). + fn mul_add_biguint( + &mut self, + x: &BigUintTarget, + y: &BigUintTarget, + z: &BigUintTarget, + ) -> BigUintTarget { + let prod = self.mul_biguint(x, y); + self.add_biguint(&prod, z) + } + + fn div_rem_biguint( + &mut self, + a: &BigUintTarget, + b: &BigUintTarget, + ) -> (BigUintTarget, BigUintTarget) { + let a_len = a.limbs.len(); + let b_len = b.limbs.len(); + let div_num_limbs = if b_len > a_len + 1 { + 0 + } else { + a_len - b_len + 1 + }; + let div = self.add_virtual_biguint_target(div_num_limbs); + let rem = self.add_virtual_biguint_target(b_len); + + self.add_simple_generator(BigUintDivRemGenerator:: { + a: a.clone(), + b: b.clone(), + div: div.clone(), + rem: rem.clone(), + _phantom: PhantomData, + }); + + let div_b = self.mul_biguint(&div, b); + let div_b_plus_rem = self.add_biguint(&div_b, &rem); + self.connect_biguint(a, &div_b_plus_rem); + + let cmp_rem_b = self.cmp_biguint(&rem, b); + self.assert_one(cmp_rem_b.target); + + (div, rem) + } + + fn div_biguint(&mut self, a: &BigUintTarget, b: &BigUintTarget) -> BigUintTarget { + let (div, _rem) = self.div_rem_biguint(a, b); + div + } + + fn rem_biguint(&mut self, a: &BigUintTarget, b: &BigUintTarget) -> BigUintTarget { + let (_div, rem) = self.div_rem_biguint(a, b); + rem + } +} + +pub fn witness_get_biguint_target, F: PrimeField>( + witness: &W, + bt: BigUintTarget, +) -> BigUint { + let base = BigUint::from(1usize << 32); + bt.limbs + .into_iter() + .rev() + .fold(BigUint::zero(), |acc, limb| { + acc * &base + witness.get_target(limb.0).to_canonical_biguint() + }) +} + +pub fn witness_set_biguint_target, F: PrimeField>( + witness: &mut W, + target: &BigUintTarget, + value: &BigUint, +) { + let mut limbs = value.to_u32_digits(); + assert!(target.num_limbs() >= limbs.len()); + limbs.resize(target.num_limbs(), 0); + for i in 0..target.num_limbs() { + witness.set_u32_target(target.limbs[i], limbs[i]); + } +} + +pub fn buffer_set_biguint_target( + buffer: &mut GeneratedValues, + target: &BigUintTarget, + value: &BigUint, +) { + let mut limbs = value.to_u32_digits(); + assert!(target.num_limbs() >= limbs.len()); + limbs.resize(target.num_limbs(), 0); + for i in 0..target.num_limbs() { + buffer.set_u32_target(target.get_limb(i), limbs[i]); + } +} + +#[derive(Debug)] +struct BigUintDivRemGenerator, const D: usize> { + a: BigUintTarget, + b: BigUintTarget, + div: BigUintTarget, + rem: BigUintTarget, + _phantom: PhantomData, +} + +impl, const D: usize> SimpleGenerator + for BigUintDivRemGenerator +{ + fn dependencies(&self) -> Vec { + self.a + .limbs + .iter() + .chain(&self.b.limbs) + .map(|&l| l.0) + .collect() + } + + fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { + let a = witness_get_biguint_target(witness, self.a.clone()); + let b = witness_get_biguint_target(witness, self.b.clone()); + let (div, rem) = a.div_rem(&b); + + buffer_set_biguint_target(out_buffer, &self.div, &div); + buffer_set_biguint_target(out_buffer, &self.rem, &rem); + } +} + +#[cfg(test)] +mod tests { + use anyhow::Result; + use num::{BigUint, FromPrimitive, Integer}; + use plonky2::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; + use plonky2::{ + iop::witness::PartialWitness, + plonk::{circuit_builder::CircuitBuilder, circuit_data::CircuitConfig}, + }; + use rand::Rng; + + use crate::gadgets::biguint::{witness_set_biguint_target, CircuitBuilderBiguint}; + + #[test] + fn test_biguint_add() -> Result<()> { + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + let mut rng = rand::thread_rng(); + + let x_value = BigUint::from_u128(rng.gen()).unwrap(); + let y_value = BigUint::from_u128(rng.gen()).unwrap(); + let expected_z_value = &x_value + &y_value; + + let config = CircuitConfig::standard_recursion_config(); + let mut pw = PartialWitness::new(); + let mut builder = CircuitBuilder::::new(config); + + let x = builder.add_virtual_biguint_target(x_value.to_u32_digits().len()); + let y = builder.add_virtual_biguint_target(y_value.to_u32_digits().len()); + let z = builder.add_biguint(&x, &y); + let expected_z = builder.add_virtual_biguint_target(expected_z_value.to_u32_digits().len()); + builder.connect_biguint(&z, &expected_z); + + witness_set_biguint_target(&mut pw, &x, &x_value); + witness_set_biguint_target(&mut pw, &y, &y_value); + witness_set_biguint_target(&mut pw, &expected_z, &expected_z_value); + + let data = builder.build::(); + let proof = data.prove(pw).unwrap(); + data.verify(proof) + } + + #[test] + fn test_biguint_sub() -> Result<()> { + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + let mut rng = rand::thread_rng(); + + let mut x_value = BigUint::from_u128(rng.gen()).unwrap(); + let mut y_value = BigUint::from_u128(rng.gen()).unwrap(); + if y_value > x_value { + (x_value, y_value) = (y_value, x_value); + } + let expected_z_value = &x_value - &y_value; + + let config = CircuitConfig::standard_recursion_config(); + let pw = PartialWitness::new(); + let mut builder = CircuitBuilder::::new(config); + + let x = builder.constant_biguint(&x_value); + let y = builder.constant_biguint(&y_value); + let z = builder.sub_biguint(&x, &y); + let expected_z = builder.constant_biguint(&expected_z_value); + + builder.connect_biguint(&z, &expected_z); + + let data = builder.build::(); + let proof = data.prove(pw).unwrap(); + data.verify(proof) + } + + #[test] + fn test_biguint_mul() -> Result<()> { + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + let mut rng = rand::thread_rng(); + + let x_value = BigUint::from_u128(rng.gen()).unwrap(); + let y_value = BigUint::from_u128(rng.gen()).unwrap(); + let expected_z_value = &x_value * &y_value; + + let config = CircuitConfig::standard_recursion_config(); + let mut pw = PartialWitness::new(); + let mut builder = CircuitBuilder::::new(config); + + let x = builder.add_virtual_biguint_target(x_value.to_u32_digits().len()); + let y = builder.add_virtual_biguint_target(y_value.to_u32_digits().len()); + let z = builder.mul_biguint(&x, &y); + let expected_z = builder.add_virtual_biguint_target(expected_z_value.to_u32_digits().len()); + builder.connect_biguint(&z, &expected_z); + + witness_set_biguint_target(&mut pw, &x, &x_value); + witness_set_biguint_target(&mut pw, &y, &y_value); + witness_set_biguint_target(&mut pw, &expected_z, &expected_z_value); + + let data = builder.build::(); + let proof = data.prove(pw).unwrap(); + data.verify(proof) + } + + #[test] + fn test_biguint_cmp() -> Result<()> { + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + let mut rng = rand::thread_rng(); + + let x_value = BigUint::from_u128(rng.gen()).unwrap(); + let y_value = BigUint::from_u128(rng.gen()).unwrap(); + + let config = CircuitConfig::standard_recursion_config(); + let pw = PartialWitness::new(); + let mut builder = CircuitBuilder::::new(config); + + let x = builder.constant_biguint(&x_value); + let y = builder.constant_biguint(&y_value); + let cmp = builder.cmp_biguint(&x, &y); + let expected_cmp = builder.constant_bool(x_value <= y_value); + + builder.connect(cmp.target, expected_cmp.target); + + let data = builder.build::(); + let proof = data.prove(pw).unwrap(); + data.verify(proof) + } + + #[test] + fn test_biguint_div_rem() -> Result<()> { + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + let mut rng = rand::thread_rng(); + + let mut x_value = BigUint::from_u128(rng.gen()).unwrap(); + let mut y_value = BigUint::from_u128(rng.gen()).unwrap(); + if y_value > x_value { + (x_value, y_value) = (y_value, x_value); + } + let (expected_div_value, expected_rem_value) = x_value.div_rem(&y_value); + + let config = CircuitConfig::standard_recursion_config(); + let pw = PartialWitness::new(); + let mut builder = CircuitBuilder::::new(config); + + let x = builder.constant_biguint(&x_value); + let y = builder.constant_biguint(&y_value); + let (div, rem) = builder.div_rem_biguint(&x, &y); + + let expected_div = builder.constant_biguint(&expected_div_value); + let expected_rem = builder.constant_biguint(&expected_rem_value); + + builder.connect_biguint(&div, &expected_div); + builder.connect_biguint(&rem, &expected_rem); + + let data = builder.build::(); + let proof = data.prove(pw).unwrap(); + data.verify(proof) + } +} diff --git a/ecdsa/src/gadgets/curve.rs b/ecdsa/src/gadgets/curve.rs new file mode 100644 index 00000000..e8a277c4 --- /dev/null +++ b/ecdsa/src/gadgets/curve.rs @@ -0,0 +1,484 @@ +use plonky2::hash::hash_types::RichField; +use plonky2::iop::target::BoolTarget; +use plonky2::plonk::circuit_builder::CircuitBuilder; +use plonky2_field::extension_field::Extendable; +use plonky2_field::field_types::Field; + +use crate::curve::curve_types::{AffinePoint, Curve, CurveScalar}; +use crate::gadgets::nonnative::{CircuitBuilderNonNative, NonNativeTarget}; + +/// A Target representing an affine point on the curve `C`. We use incomplete arithmetic for efficiency, +/// so we assume these points are not zero. +#[derive(Clone, Debug)] +pub struct AffinePointTarget { + pub x: NonNativeTarget, + pub y: NonNativeTarget, +} + +impl AffinePointTarget { + pub fn to_vec(&self) -> Vec> { + vec![self.x.clone(), self.y.clone()] + } +} + +pub trait CircuitBuilderCurve, const D: usize> { + fn constant_affine_point(&mut self, point: AffinePoint) -> AffinePointTarget; + + fn connect_affine_point( + &mut self, + lhs: &AffinePointTarget, + rhs: &AffinePointTarget, + ); + + fn add_virtual_affine_point_target(&mut self) -> AffinePointTarget; + + fn curve_assert_valid(&mut self, p: &AffinePointTarget); + + fn curve_neg(&mut self, p: &AffinePointTarget) -> AffinePointTarget; + + fn curve_conditional_neg( + &mut self, + p: &AffinePointTarget, + b: BoolTarget, + ) -> AffinePointTarget; + + fn curve_double(&mut self, p: &AffinePointTarget) -> AffinePointTarget; + + fn curve_repeated_double( + &mut self, + p: &AffinePointTarget, + n: usize, + ) -> AffinePointTarget; + + // Add two points, which are assumed to be non-equal. + fn curve_add( + &mut self, + p1: &AffinePointTarget, + p2: &AffinePointTarget, + ) -> AffinePointTarget; + + fn curve_conditional_add( + &mut self, + p1: &AffinePointTarget, + p2: &AffinePointTarget, + b: BoolTarget, + ) -> AffinePointTarget; + + fn curve_scalar_mul( + &mut self, + p: &AffinePointTarget, + n: &NonNativeTarget, + ) -> AffinePointTarget; +} + +impl, const D: usize> CircuitBuilderCurve + for CircuitBuilder +{ + fn constant_affine_point(&mut self, point: AffinePoint) -> AffinePointTarget { + debug_assert!(!point.zero); + AffinePointTarget { + x: self.constant_nonnative(point.x), + y: self.constant_nonnative(point.y), + } + } + + fn connect_affine_point( + &mut self, + lhs: &AffinePointTarget, + rhs: &AffinePointTarget, + ) { + self.connect_nonnative(&lhs.x, &rhs.x); + self.connect_nonnative(&lhs.y, &rhs.y); + } + + fn add_virtual_affine_point_target(&mut self) -> AffinePointTarget { + let x = self.add_virtual_nonnative_target(); + let y = self.add_virtual_nonnative_target(); + + AffinePointTarget { x, y } + } + + fn curve_assert_valid(&mut self, p: &AffinePointTarget) { + let a = self.constant_nonnative(C::A); + let b = self.constant_nonnative(C::B); + + let y_squared = self.mul_nonnative(&p.y, &p.y); + let x_squared = self.mul_nonnative(&p.x, &p.x); + let x_cubed = self.mul_nonnative(&x_squared, &p.x); + let a_x = self.mul_nonnative(&a, &p.x); + let a_x_plus_b = self.add_nonnative(&a_x, &b); + let rhs = self.add_nonnative(&x_cubed, &a_x_plus_b); + + self.connect_nonnative(&y_squared, &rhs); + } + + fn curve_neg(&mut self, p: &AffinePointTarget) -> AffinePointTarget { + let neg_y = self.neg_nonnative(&p.y); + AffinePointTarget { + x: p.x.clone(), + y: neg_y, + } + } + + fn curve_conditional_neg( + &mut self, + p: &AffinePointTarget, + b: BoolTarget, + ) -> AffinePointTarget { + AffinePointTarget { + x: p.x.clone(), + y: self.nonnative_conditional_neg(&p.y, b), + } + } + + fn curve_double(&mut self, p: &AffinePointTarget) -> AffinePointTarget { + let AffinePointTarget { x, y } = p; + let double_y = self.add_nonnative(y, y); + let inv_double_y = self.inv_nonnative(&double_y); + let x_squared = self.mul_nonnative(x, x); + let double_x_squared = self.add_nonnative(&x_squared, &x_squared); + let triple_x_squared = self.add_nonnative(&double_x_squared, &x_squared); + + let a = self.constant_nonnative(C::A); + let triple_xx_a = self.add_nonnative(&triple_x_squared, &a); + let lambda = self.mul_nonnative(&triple_xx_a, &inv_double_y); + let lambda_squared = self.mul_nonnative(&lambda, &lambda); + let x_double = self.add_nonnative(x, x); + + let x3 = self.sub_nonnative(&lambda_squared, &x_double); + + let x_diff = self.sub_nonnative(x, &x3); + let lambda_x_diff = self.mul_nonnative(&lambda, &x_diff); + + let y3 = self.sub_nonnative(&lambda_x_diff, y); + + AffinePointTarget { x: x3, y: y3 } + } + + fn curve_repeated_double( + &mut self, + p: &AffinePointTarget, + n: usize, + ) -> AffinePointTarget { + let mut result = p.clone(); + + for _ in 0..n { + result = self.curve_double(&result); + } + + result + } + + // Add two points, which are assumed to be non-equal. + fn curve_add( + &mut self, + p1: &AffinePointTarget, + p2: &AffinePointTarget, + ) -> AffinePointTarget { + let AffinePointTarget { x: x1, y: y1 } = p1; + let AffinePointTarget { x: x2, y: y2 } = p2; + + let u = self.sub_nonnative(y2, y1); + let v = self.sub_nonnative(x2, x1); + let v_inv = self.inv_nonnative(&v); + let s = self.mul_nonnative(&u, &v_inv); + let s_squared = self.mul_nonnative(&s, &s); + let x_sum = self.add_nonnative(x2, x1); + let x3 = self.sub_nonnative(&s_squared, &x_sum); + let x_diff = self.sub_nonnative(x1, &x3); + let prod = self.mul_nonnative(&s, &x_diff); + let y3 = self.sub_nonnative(&prod, y1); + + AffinePointTarget { x: x3, y: y3 } + } + + fn curve_conditional_add( + &mut self, + p1: &AffinePointTarget, + p2: &AffinePointTarget, + b: BoolTarget, + ) -> AffinePointTarget { + let not_b = self.not(b); + let sum = self.curve_add(p1, p2); + let x_if_true = self.mul_nonnative_by_bool(&sum.x, b); + let y_if_true = self.mul_nonnative_by_bool(&sum.y, b); + let x_if_false = self.mul_nonnative_by_bool(&p1.x, not_b); + let y_if_false = self.mul_nonnative_by_bool(&p1.y, not_b); + + let x = self.add_nonnative(&x_if_true, &x_if_false); + let y = self.add_nonnative(&y_if_true, &y_if_false); + + AffinePointTarget { x, y } + } + + fn curve_scalar_mul( + &mut self, + p: &AffinePointTarget, + n: &NonNativeTarget, + ) -> AffinePointTarget { + let bits = self.split_nonnative_to_bits(n); + + let rando = (CurveScalar(C::ScalarField::rand()) * C::GENERATOR_PROJECTIVE).to_affine(); + let randot = self.constant_affine_point(rando); + // Result starts at `rando`, which is later subtracted, because we don't support arithmetic with the zero point. + let mut result = self.add_virtual_affine_point_target(); + self.connect_affine_point(&randot, &result); + + let mut two_i_times_p = self.add_virtual_affine_point_target(); + self.connect_affine_point(p, &two_i_times_p); + + for &bit in bits.iter() { + let not_bit = self.not(bit); + + let result_plus_2_i_p = self.curve_add(&result, &two_i_times_p); + + let new_x_if_bit = self.mul_nonnative_by_bool(&result_plus_2_i_p.x, bit); + let new_x_if_not_bit = self.mul_nonnative_by_bool(&result.x, not_bit); + let new_y_if_bit = self.mul_nonnative_by_bool(&result_plus_2_i_p.y, bit); + let new_y_if_not_bit = self.mul_nonnative_by_bool(&result.y, not_bit); + + let new_x = self.add_nonnative(&new_x_if_bit, &new_x_if_not_bit); + let new_y = self.add_nonnative(&new_y_if_bit, &new_y_if_not_bit); + + result = AffinePointTarget { x: new_x, y: new_y }; + + two_i_times_p = self.curve_double(&two_i_times_p); + } + + // Subtract off result's intial value of `rando`. + let neg_r = self.curve_neg(&randot); + result = self.curve_add(&result, &neg_r); + + result + } +} + +#[cfg(test)] +mod tests { + use std::ops::Neg; + + use anyhow::Result; + use plonky2::iop::witness::PartialWitness; + use plonky2::plonk::circuit_builder::CircuitBuilder; + use plonky2::plonk::circuit_data::CircuitConfig; + use plonky2::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; + use plonky2_field::field_types::Field; + use plonky2_field::secp256k1_base::Secp256K1Base; + use plonky2_field::secp256k1_scalar::Secp256K1Scalar; + + use crate::curve::curve_types::{AffinePoint, Curve, CurveScalar}; + use crate::curve::secp256k1::Secp256K1; + use crate::gadgets::curve::CircuitBuilderCurve; + use crate::gadgets::nonnative::CircuitBuilderNonNative; + + #[test] + fn test_curve_point_is_valid() -> Result<()> { + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + + let config = CircuitConfig::standard_ecc_config(); + + let pw = PartialWitness::new(); + let mut builder = CircuitBuilder::::new(config); + + let g = Secp256K1::GENERATOR_AFFINE; + let g_target = builder.constant_affine_point(g); + let neg_g_target = builder.curve_neg(&g_target); + + builder.curve_assert_valid(&g_target); + builder.curve_assert_valid(&neg_g_target); + + let data = builder.build::(); + let proof = data.prove(pw).unwrap(); + + data.verify(proof) + } + + #[test] + #[should_panic] + fn test_curve_point_is_not_valid() { + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + + let config = CircuitConfig::standard_ecc_config(); + + let pw = PartialWitness::new(); + let mut builder = CircuitBuilder::::new(config); + + let g = Secp256K1::GENERATOR_AFFINE; + let not_g = AffinePoint:: { + x: g.x, + y: g.y + Secp256K1Base::ONE, + zero: g.zero, + }; + let not_g_target = builder.constant_affine_point(not_g); + + builder.curve_assert_valid(¬_g_target); + + let data = builder.build::(); + let proof = data.prove(pw).unwrap(); + + data.verify(proof).unwrap() + } + + #[test] + fn test_curve_double() -> Result<()> { + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + + let config = CircuitConfig::standard_ecc_config(); + + let pw = PartialWitness::new(); + let mut builder = CircuitBuilder::::new(config); + + let g = Secp256K1::GENERATOR_AFFINE; + let g_target = builder.constant_affine_point(g); + let neg_g_target = builder.curve_neg(&g_target); + + let double_g = g.double(); + let double_g_expected = builder.constant_affine_point(double_g); + builder.curve_assert_valid(&double_g_expected); + + let double_neg_g = (-g).double(); + let double_neg_g_expected = builder.constant_affine_point(double_neg_g); + builder.curve_assert_valid(&double_neg_g_expected); + + let double_g_actual = builder.curve_double(&g_target); + let double_neg_g_actual = builder.curve_double(&neg_g_target); + builder.curve_assert_valid(&double_g_actual); + builder.curve_assert_valid(&double_neg_g_actual); + + builder.connect_affine_point(&double_g_expected, &double_g_actual); + builder.connect_affine_point(&double_neg_g_expected, &double_neg_g_actual); + + let data = builder.build::(); + let proof = data.prove(pw).unwrap(); + + data.verify(proof) + } + + #[test] + fn test_curve_add() -> Result<()> { + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + + let config = CircuitConfig::standard_ecc_config(); + + let pw = PartialWitness::new(); + let mut builder = CircuitBuilder::::new(config); + + let g = Secp256K1::GENERATOR_AFFINE; + let double_g = g.double(); + let g_plus_2g = (g + double_g).to_affine(); + let g_plus_2g_expected = builder.constant_affine_point(g_plus_2g); + builder.curve_assert_valid(&g_plus_2g_expected); + + let g_target = builder.constant_affine_point(g); + let double_g_target = builder.curve_double(&g_target); + let g_plus_2g_actual = builder.curve_add(&g_target, &double_g_target); + builder.curve_assert_valid(&g_plus_2g_actual); + + builder.connect_affine_point(&g_plus_2g_expected, &g_plus_2g_actual); + + let data = builder.build::(); + let proof = data.prove(pw).unwrap(); + + data.verify(proof) + } + + #[test] + fn test_curve_conditional_add() -> Result<()> { + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + + let config = CircuitConfig::standard_ecc_config(); + + let pw = PartialWitness::new(); + let mut builder = CircuitBuilder::::new(config); + + let g = Secp256K1::GENERATOR_AFFINE; + let double_g = g.double(); + let g_plus_2g = (g + double_g).to_affine(); + let g_plus_2g_expected = builder.constant_affine_point(g_plus_2g); + + let g_expected = builder.constant_affine_point(g); + let double_g_target = builder.curve_double(&g_expected); + let t = builder._true(); + let f = builder._false(); + let g_plus_2g_actual = builder.curve_conditional_add(&g_expected, &double_g_target, t); + let g_actual = builder.curve_conditional_add(&g_expected, &double_g_target, f); + + builder.connect_affine_point(&g_plus_2g_expected, &g_plus_2g_actual); + builder.connect_affine_point(&g_expected, &g_actual); + + let data = builder.build::(); + let proof = data.prove(pw).unwrap(); + + data.verify(proof) + } + + #[test] + #[ignore] + fn test_curve_mul() -> Result<()> { + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + + let config = CircuitConfig::standard_ecc_config(); + + let pw = PartialWitness::new(); + let mut builder = CircuitBuilder::::new(config); + + let g = Secp256K1::GENERATOR_PROJECTIVE.to_affine(); + let five = Secp256K1Scalar::from_canonical_usize(5); + let neg_five = five.neg(); + let neg_five_scalar = CurveScalar::(neg_five); + let neg_five_g = (neg_five_scalar * g.to_projective()).to_affine(); + let neg_five_g_expected = builder.constant_affine_point(neg_five_g); + builder.curve_assert_valid(&neg_five_g_expected); + + let g_target = builder.constant_affine_point(g); + let neg_five_target = builder.constant_nonnative(neg_five); + let neg_five_g_actual = builder.curve_scalar_mul(&g_target, &neg_five_target); + builder.curve_assert_valid(&neg_five_g_actual); + + builder.connect_affine_point(&neg_five_g_expected, &neg_five_g_actual); + + let data = builder.build::(); + let proof = data.prove(pw).unwrap(); + + data.verify(proof) + } + + #[test] + #[ignore] + fn test_curve_random() -> Result<()> { + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + + let config = CircuitConfig::standard_ecc_config(); + + let pw = PartialWitness::new(); + let mut builder = CircuitBuilder::::new(config); + + let rando = + (CurveScalar(Secp256K1Scalar::rand()) * Secp256K1::GENERATOR_PROJECTIVE).to_affine(); + let randot = builder.constant_affine_point(rando); + + let two_target = builder.constant_nonnative(Secp256K1Scalar::TWO); + let randot_doubled = builder.curve_double(&randot); + let randot_times_two = builder.curve_scalar_mul(&randot, &two_target); + builder.connect_affine_point(&randot_doubled, &randot_times_two); + + let data = builder.build::(); + let proof = data.prove(pw).unwrap(); + + data.verify(proof) + } +} diff --git a/ecdsa/src/gadgets/curve_fixed_base.rs b/ecdsa/src/gadgets/curve_fixed_base.rs new file mode 100644 index 00000000..e64ec134 --- /dev/null +++ b/ecdsa/src/gadgets/curve_fixed_base.rs @@ -0,0 +1,113 @@ +use num::BigUint; +use plonky2::hash::hash_types::RichField; +use plonky2::hash::keccak::KeccakHash; +use plonky2::plonk::circuit_builder::CircuitBuilder; +use plonky2::plonk::config::{GenericHashOut, Hasher}; +use plonky2_field::extension_field::Extendable; +use plonky2_field::field_types::Field; + +use crate::curve::curve_types::{AffinePoint, Curve, CurveScalar}; +use crate::gadgets::curve::{AffinePointTarget, CircuitBuilderCurve}; +use crate::gadgets::curve_windowed_mul::CircuitBuilderWindowedMul; +use crate::gadgets::nonnative::NonNativeTarget; +use crate::gadgets::split_nonnative::CircuitBuilderSplit; + +/// Compute windowed fixed-base scalar multiplication, using a 4-bit window. +pub fn fixed_base_curve_mul_circuit, const D: usize>( + builder: &mut CircuitBuilder, + base: AffinePoint, + scalar: &NonNativeTarget, +) -> AffinePointTarget { + // Holds `(16^i) * base` for `i=0..scalar.value.limbs.len() * 8`. + let scaled_base = (0..scalar.value.limbs.len() * 8).scan(base, |acc, _| { + let tmp = *acc; + for _ in 0..4 { + *acc = acc.double(); + } + Some(tmp) + }); + + let limbs = builder.split_nonnative_to_4_bit_limbs(scalar); + + let hash_0 = KeccakHash::<32>::hash_no_pad(&[F::ZERO]); + let hash_0_scalar = C::ScalarField::from_biguint(BigUint::from_bytes_le( + &GenericHashOut::::to_bytes(&hash_0), + )); + let rando = (CurveScalar(hash_0_scalar) * C::GENERATOR_PROJECTIVE).to_affine(); + + let zero = builder.zero(); + let mut result = builder.constant_affine_point(rando); + // `s * P = sum s_i * P_i` with `P_i = (16^i) * P` and `s = sum s_i * (16^i)`. + for (limb, point) in limbs.into_iter().zip(scaled_base) { + // `muls_point[t] = t * P_i` for `t=0..16`. + let muls_point = (0..16) + .scan(AffinePoint::ZERO, |acc, _| { + let tmp = *acc; + *acc = (point + *acc).to_affine(); + Some(tmp) + }) + .map(|p| builder.constant_affine_point(p)) + .collect::>(); + let is_zero = builder.is_equal(limb, zero); + let should_add = builder.not(is_zero); + // `r = s_i * P_i` + let r = builder.random_access_curve_points(limb, muls_point); + result = builder.curve_conditional_add(&result, &r, should_add); + } + + let to_add = builder.constant_affine_point(-rando); + builder.curve_add(&result, &to_add) +} + +#[cfg(test)] +mod tests { + use anyhow::Result; + use plonky2::iop::witness::PartialWitness; + use plonky2::plonk::circuit_builder::CircuitBuilder; + use plonky2::plonk::circuit_data::CircuitConfig; + use plonky2::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; + use plonky2_field::field_types::Field; + use plonky2_field::field_types::PrimeField; + use plonky2_field::secp256k1_scalar::Secp256K1Scalar; + + use crate::curve::curve_types::{Curve, CurveScalar}; + use crate::curve::secp256k1::Secp256K1; + use crate::gadgets::biguint::witness_set_biguint_target; + use crate::gadgets::curve::CircuitBuilderCurve; + use crate::gadgets::curve_fixed_base::fixed_base_curve_mul_circuit; + use crate::gadgets::nonnative::CircuitBuilderNonNative; + + #[test] + #[ignore] + fn test_fixed_base() -> Result<()> { + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + + let config = CircuitConfig::standard_ecc_config(); + + let mut pw = PartialWitness::new(); + let mut builder = CircuitBuilder::::new(config); + + let g = Secp256K1::GENERATOR_AFFINE; + let n = Secp256K1Scalar::rand(); + + let res = (CurveScalar(n) * g.to_projective()).to_affine(); + let res_expected = builder.constant_affine_point(res); + builder.curve_assert_valid(&res_expected); + + let n_target = builder.add_virtual_nonnative_target::(); + witness_set_biguint_target(&mut pw, &n_target.value, &n.to_canonical_biguint()); + + let res_target = fixed_base_curve_mul_circuit(&mut builder, g, &n_target); + builder.curve_assert_valid(&res_target); + + builder.connect_affine_point(&res_target, &res_expected); + + dbg!(builder.num_gates()); + let data = builder.build::(); + let proof = data.prove(pw).unwrap(); + + data.verify(proof) + } +} diff --git a/ecdsa/src/gadgets/curve_msm.rs b/ecdsa/src/gadgets/curve_msm.rs new file mode 100644 index 00000000..c57cecd2 --- /dev/null +++ b/ecdsa/src/gadgets/curve_msm.rs @@ -0,0 +1,136 @@ +use num::BigUint; +use plonky2::hash::hash_types::RichField; +use plonky2::hash::keccak::KeccakHash; +use plonky2::plonk::circuit_builder::CircuitBuilder; +use plonky2::plonk::config::{GenericHashOut, Hasher}; +use plonky2_field::extension_field::Extendable; +use plonky2_field::field_types::Field; + +use crate::curve::curve_types::{Curve, CurveScalar}; +use crate::gadgets::curve::{AffinePointTarget, CircuitBuilderCurve}; +use crate::gadgets::curve_windowed_mul::CircuitBuilderWindowedMul; +use crate::gadgets::nonnative::NonNativeTarget; +use crate::gadgets::split_nonnative::CircuitBuilderSplit; + +/// Computes `n*p + m*q` using windowed MSM, with a 2-bit window. +/// See Algorithm 9.23 in Handbook of Elliptic and Hyperelliptic Curve Cryptography for a +/// description. +/// Note: Doesn't work if `p == q`. +pub fn curve_msm_circuit, const D: usize>( + builder: &mut CircuitBuilder, + p: &AffinePointTarget, + q: &AffinePointTarget, + n: &NonNativeTarget, + m: &NonNativeTarget, +) -> AffinePointTarget { + let limbs_n = builder.split_nonnative_to_2_bit_limbs(n); + let limbs_m = builder.split_nonnative_to_2_bit_limbs(m); + assert_eq!(limbs_n.len(), limbs_m.len()); + let num_limbs = limbs_n.len(); + + let hash_0 = KeccakHash::<32>::hash_no_pad(&[F::ZERO]); + let hash_0_scalar = C::ScalarField::from_biguint(BigUint::from_bytes_le( + &GenericHashOut::::to_bytes(&hash_0), + )); + let rando = (CurveScalar(hash_0_scalar) * C::GENERATOR_PROJECTIVE).to_affine(); + let rando_t = builder.constant_affine_point(rando); + let neg_rando = builder.constant_affine_point(-rando); + + // Precomputes `precomputation[i + 4*j] = i*p + j*q` for `i,j=0..4`. + let mut precomputation = vec![p.clone(); 16]; + let mut cur_p = rando_t.clone(); + let mut cur_q = rando_t.clone(); + for i in 0..4 { + precomputation[i] = cur_p.clone(); + precomputation[4 * i] = cur_q.clone(); + cur_p = builder.curve_add(&cur_p, p); + cur_q = builder.curve_add(&cur_q, q); + } + for i in 1..4 { + precomputation[i] = builder.curve_add(&precomputation[i], &neg_rando); + precomputation[4 * i] = builder.curve_add(&precomputation[4 * i], &neg_rando); + } + for i in 1..4 { + for j in 1..4 { + precomputation[i + 4 * j] = + builder.curve_add(&precomputation[i], &precomputation[4 * j]); + } + } + + let four = builder.constant(F::from_canonical_usize(4)); + + let zero = builder.zero(); + let mut result = rando_t; + for (limb_n, limb_m) in limbs_n.into_iter().zip(limbs_m).rev() { + result = builder.curve_repeated_double(&result, 2); + let index = builder.mul_add(four, limb_m, limb_n); + let r = builder.random_access_curve_points(index, precomputation.clone()); + let is_zero = builder.is_equal(index, zero); + let should_add = builder.not(is_zero); + result = builder.curve_conditional_add(&result, &r, should_add); + } + let starting_point_multiplied = (0..2 * num_limbs).fold(rando, |acc, _| acc.double()); + let to_add = builder.constant_affine_point(-starting_point_multiplied); + result = builder.curve_add(&result, &to_add); + + result +} + +#[cfg(test)] +mod tests { + use anyhow::Result; + use plonky2::iop::witness::PartialWitness; + use plonky2::plonk::circuit_builder::CircuitBuilder; + use plonky2::plonk::circuit_data::CircuitConfig; + use plonky2::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; + use plonky2_field::field_types::Field; + use plonky2_field::secp256k1_scalar::Secp256K1Scalar; + + use crate::curve::curve_types::{Curve, CurveScalar}; + use crate::curve::secp256k1::Secp256K1; + use crate::gadgets::curve::CircuitBuilderCurve; + use crate::gadgets::curve_msm::curve_msm_circuit; + use crate::gadgets::nonnative::CircuitBuilderNonNative; + + #[test] + #[ignore] + fn test_curve_msm() -> Result<()> { + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + + let config = CircuitConfig::standard_ecc_config(); + + let pw = PartialWitness::new(); + let mut builder = CircuitBuilder::::new(config); + + let p = + (CurveScalar(Secp256K1Scalar::rand()) * Secp256K1::GENERATOR_PROJECTIVE).to_affine(); + let q = + (CurveScalar(Secp256K1Scalar::rand()) * Secp256K1::GENERATOR_PROJECTIVE).to_affine(); + let n = Secp256K1Scalar::rand(); + let m = Secp256K1Scalar::rand(); + + let res = + (CurveScalar(n) * p.to_projective() + CurveScalar(m) * q.to_projective()).to_affine(); + let res_expected = builder.constant_affine_point(res); + builder.curve_assert_valid(&res_expected); + + let p_target = builder.constant_affine_point(p); + let q_target = builder.constant_affine_point(q); + let n_target = builder.constant_nonnative(n); + let m_target = builder.constant_nonnative(m); + + let res_target = + curve_msm_circuit(&mut builder, &p_target, &q_target, &n_target, &m_target); + builder.curve_assert_valid(&res_target); + + builder.connect_affine_point(&res_target, &res_expected); + + dbg!(builder.num_gates()); + let data = builder.build::(); + let proof = data.prove(pw).unwrap(); + + data.verify(proof) + } +} diff --git a/ecdsa/src/gadgets/curve_windowed_mul.rs b/ecdsa/src/gadgets/curve_windowed_mul.rs new file mode 100644 index 00000000..05c2c58a --- /dev/null +++ b/ecdsa/src/gadgets/curve_windowed_mul.rs @@ -0,0 +1,256 @@ +use std::marker::PhantomData; + +use num::BigUint; +use plonky2::gadgets::arithmetic_u32::U32Target; +use plonky2::hash::hash_types::RichField; +use plonky2::hash::keccak::KeccakHash; +use plonky2::iop::target::{BoolTarget, Target}; +use plonky2::plonk::circuit_builder::CircuitBuilder; +use plonky2::plonk::config::{GenericHashOut, Hasher}; +use plonky2_field::extension_field::Extendable; +use plonky2_field::field_types::Field; + +use crate::curve::curve_types::{Curve, CurveScalar}; +use crate::gadgets::biguint::BigUintTarget; +use crate::gadgets::curve::{AffinePointTarget, CircuitBuilderCurve}; +use crate::gadgets::nonnative::{CircuitBuilderNonNative, NonNativeTarget}; +use crate::gadgets::split_nonnative::CircuitBuilderSplit; + +const WINDOW_SIZE: usize = 4; + +pub trait CircuitBuilderWindowedMul, const D: usize> { + fn precompute_window( + &mut self, + p: &AffinePointTarget, + ) -> Vec>; + + fn random_access_curve_points( + &mut self, + access_index: Target, + v: Vec>, + ) -> AffinePointTarget; + + fn if_affine_point( + &mut self, + b: BoolTarget, + p1: &AffinePointTarget, + p2: &AffinePointTarget, + ) -> AffinePointTarget; + + fn curve_scalar_mul_windowed( + &mut self, + p: &AffinePointTarget, + n: &NonNativeTarget, + ) -> AffinePointTarget; +} + +impl, const D: usize> CircuitBuilderWindowedMul + for CircuitBuilder +{ + fn precompute_window( + &mut self, + p: &AffinePointTarget, + ) -> Vec> { + let g = (CurveScalar(C::ScalarField::rand()) * C::GENERATOR_PROJECTIVE).to_affine(); + let neg = { + let mut neg = g; + neg.y = -neg.y; + self.constant_affine_point(neg) + }; + + let mut multiples = vec![self.constant_affine_point(g)]; + for i in 1..1 << WINDOW_SIZE { + multiples.push(self.curve_add(p, &multiples[i - 1])); + } + for i in 1..1 << WINDOW_SIZE { + multiples[i] = self.curve_add(&neg, &multiples[i]); + } + multiples + } + + fn random_access_curve_points( + &mut self, + access_index: Target, + v: Vec>, + ) -> AffinePointTarget { + let num_limbs = C::BaseField::BITS / 32; + let zero = self.zero_u32(); + let x_limbs: Vec> = (0..num_limbs) + .map(|i| { + v.iter() + .map(|p| p.x.value.limbs.get(i).unwrap_or(&zero).0) + .collect() + }) + .collect(); + let y_limbs: Vec> = (0..num_limbs) + .map(|i| { + v.iter() + .map(|p| p.y.value.limbs.get(i).unwrap_or(&zero).0) + .collect() + }) + .collect(); + + let selected_x_limbs: Vec<_> = x_limbs + .iter() + .map(|limbs| U32Target(self.random_access(access_index, limbs.clone()))) + .collect(); + let selected_y_limbs: Vec<_> = y_limbs + .iter() + .map(|limbs| U32Target(self.random_access(access_index, limbs.clone()))) + .collect(); + + let x = NonNativeTarget { + value: BigUintTarget { + limbs: selected_x_limbs, + }, + _phantom: PhantomData, + }; + let y = NonNativeTarget { + value: BigUintTarget { + limbs: selected_y_limbs, + }, + _phantom: PhantomData, + }; + AffinePointTarget { x, y } + } + + fn if_affine_point( + &mut self, + b: BoolTarget, + p1: &AffinePointTarget, + p2: &AffinePointTarget, + ) -> AffinePointTarget { + let new_x = self.if_nonnative(b, &p1.x, &p2.x); + let new_y = self.if_nonnative(b, &p1.y, &p2.y); + AffinePointTarget { x: new_x, y: new_y } + } + + fn curve_scalar_mul_windowed( + &mut self, + p: &AffinePointTarget, + n: &NonNativeTarget, + ) -> AffinePointTarget { + let hash_0 = KeccakHash::<25>::hash_no_pad(&[F::ZERO]); + let hash_0_scalar = C::ScalarField::from_biguint(BigUint::from_bytes_le( + &GenericHashOut::::to_bytes(&hash_0), + )); + let starting_point = CurveScalar(hash_0_scalar) * C::GENERATOR_PROJECTIVE; + let starting_point_multiplied = { + let mut cur = starting_point; + for _ in 0..C::ScalarField::BITS { + cur = cur.double(); + } + cur + }; + + let mut result = self.constant_affine_point(starting_point.to_affine()); + + let precomputation = self.precompute_window(p); + let zero = self.zero(); + + let windows = self.split_nonnative_to_4_bit_limbs(n); + for i in (0..windows.len()).rev() { + result = self.curve_repeated_double(&result, WINDOW_SIZE); + let window = windows[i]; + + let to_add = self.random_access_curve_points(window, precomputation.clone()); + let is_zero = self.is_equal(window, zero); + let should_add = self.not(is_zero); + result = self.curve_conditional_add(&result, &to_add, should_add); + } + + let to_subtract = self.constant_affine_point(starting_point_multiplied.to_affine()); + let to_add = self.curve_neg(&to_subtract); + result = self.curve_add(&result, &to_add); + + result + } +} + +#[cfg(test)] +mod tests { + use std::ops::Neg; + + use anyhow::Result; + use plonky2::iop::witness::PartialWitness; + use plonky2::plonk::circuit_builder::CircuitBuilder; + use plonky2::plonk::circuit_data::CircuitConfig; + use plonky2::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; + use plonky2_field::field_types::Field; + use plonky2_field::secp256k1_scalar::Secp256K1Scalar; + use rand::Rng; + + use crate::curve::curve_types::{Curve, CurveScalar}; + use crate::curve::secp256k1::Secp256K1; + use crate::gadgets::curve::CircuitBuilderCurve; + use crate::gadgets::curve_windowed_mul::CircuitBuilderWindowedMul; + use crate::gadgets::nonnative::CircuitBuilderNonNative; + + #[test] + fn test_random_access_curve_points() -> Result<()> { + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + + let config = CircuitConfig::standard_ecc_config(); + + let pw = PartialWitness::new(); + let mut builder = CircuitBuilder::::new(config); + + let num_points = 16; + let points: Vec<_> = (0..num_points) + .map(|_| { + let g = (CurveScalar(Secp256K1Scalar::rand()) * Secp256K1::GENERATOR_PROJECTIVE) + .to_affine(); + builder.constant_affine_point(g) + }) + .collect(); + + let mut rng = rand::thread_rng(); + let access_index = rng.gen::() % num_points; + + let access_index_target = builder.constant(F::from_canonical_usize(access_index)); + let selected = builder.random_access_curve_points(access_index_target, points.clone()); + let expected = points[access_index].clone(); + builder.connect_affine_point(&selected, &expected); + + let data = builder.build::(); + let proof = data.prove(pw).unwrap(); + + data.verify(proof) + } + + #[test] + #[ignore] + fn test_curve_windowed_mul() -> Result<()> { + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + + let config = CircuitConfig::standard_ecc_config(); + + let pw = PartialWitness::new(); + let mut builder = CircuitBuilder::::new(config); + + let g = + (CurveScalar(Secp256K1Scalar::rand()) * Secp256K1::GENERATOR_PROJECTIVE).to_affine(); + let five = Secp256K1Scalar::from_canonical_usize(5); + let neg_five = five.neg(); + let neg_five_scalar = CurveScalar::(neg_five); + let neg_five_g = (neg_five_scalar * g.to_projective()).to_affine(); + let neg_five_g_expected = builder.constant_affine_point(neg_five_g); + builder.curve_assert_valid(&neg_five_g_expected); + + let g_target = builder.constant_affine_point(g); + let neg_five_target = builder.constant_nonnative(neg_five); + let neg_five_g_actual = builder.curve_scalar_mul_windowed(&g_target, &neg_five_target); + builder.curve_assert_valid(&neg_five_g_actual); + + builder.connect_affine_point(&neg_five_g_expected, &neg_five_g_actual); + + let data = builder.build::(); + let proof = data.prove(pw).unwrap(); + + data.verify(proof) + } +} diff --git a/ecdsa/src/gadgets/ecdsa.rs b/ecdsa/src/gadgets/ecdsa.rs new file mode 100644 index 00000000..afd79c61 --- /dev/null +++ b/ecdsa/src/gadgets/ecdsa.rs @@ -0,0 +1,117 @@ +use std::marker::PhantomData; + +use plonky2::hash::hash_types::RichField; +use plonky2::plonk::circuit_builder::CircuitBuilder; +use plonky2_field::extension_field::Extendable; +use plonky2_field::secp256k1_scalar::Secp256K1Scalar; + +use crate::curve::curve_types::Curve; +use crate::curve::secp256k1::Secp256K1; +use crate::gadgets::curve::{AffinePointTarget, CircuitBuilderCurve}; +use crate::gadgets::curve_fixed_base::fixed_base_curve_mul_circuit; +use crate::gadgets::glv::CircuitBuilderGlv; +use crate::gadgets::nonnative::{CircuitBuilderNonNative, NonNativeTarget}; + +#[derive(Clone, Debug)] +pub struct ECDSASecretKeyTarget(NonNativeTarget); + +#[derive(Clone, Debug)] +pub struct ECDSAPublicKeyTarget(AffinePointTarget); + +#[derive(Clone, Debug)] +pub struct ECDSASignatureTarget { + pub r: NonNativeTarget, + pub s: NonNativeTarget, +} + +pub fn verify_message_circuit, const D: usize>( + builder: &mut CircuitBuilder, + msg: NonNativeTarget, + sig: ECDSASignatureTarget, + pk: ECDSAPublicKeyTarget, +) { + let ECDSASignatureTarget { r, s } = sig; + + builder.curve_assert_valid(&pk.0); + + let c = builder.inv_nonnative(&s); + let u1 = builder.mul_nonnative(&msg, &c); + let u2 = builder.mul_nonnative(&r, &c); + + let point1 = fixed_base_curve_mul_circuit(builder, Secp256K1::GENERATOR_AFFINE, &u1); + let point2 = builder.glv_mul(&pk.0, &u2); + let point = builder.curve_add(&point1, &point2); + + let x = NonNativeTarget:: { + value: point.x.value, + _phantom: PhantomData, + }; + builder.connect_nonnative(&r, &x); +} + +#[cfg(test)] +mod tests { + use anyhow::Result; + use plonky2::iop::witness::PartialWitness; + use plonky2::plonk::circuit_builder::CircuitBuilder; + use plonky2::plonk::circuit_data::CircuitConfig; + use plonky2::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; + use plonky2_field::field_types::Field; + use plonky2_field::secp256k1_scalar::Secp256K1Scalar; + + use super::{ECDSAPublicKeyTarget, ECDSASignatureTarget}; + use crate::curve::curve_types::{Curve, CurveScalar}; + use crate::curve::ecdsa::{sign_message, ECDSAPublicKey, ECDSASecretKey, ECDSASignature}; + use crate::curve::secp256k1::Secp256K1; + use crate::gadgets::curve::CircuitBuilderCurve; + use crate::gadgets::ecdsa::verify_message_circuit; + use crate::gadgets::nonnative::CircuitBuilderNonNative; + + fn test_ecdsa_circuit_with_config(config: CircuitConfig) -> Result<()> { + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + + type Curve = Secp256K1; + + let pw = PartialWitness::new(); + let mut builder = CircuitBuilder::::new(config); + + let msg = Secp256K1Scalar::rand(); + let msg_target = builder.constant_nonnative(msg); + + let sk = ECDSASecretKey::(Secp256K1Scalar::rand()); + let pk = ECDSAPublicKey((CurveScalar(sk.0) * Curve::GENERATOR_PROJECTIVE).to_affine()); + + let pk_target = ECDSAPublicKeyTarget(builder.constant_affine_point(pk.0)); + + let sig = sign_message(msg, sk); + + let ECDSASignature { r, s } = sig; + let r_target = builder.constant_nonnative(r); + let s_target = builder.constant_nonnative(s); + let sig_target = ECDSASignatureTarget { + r: r_target, + s: s_target, + }; + + verify_message_circuit(&mut builder, msg_target, sig_target, pk_target); + + dbg!(builder.num_gates()); + let data = builder.build::(); + let proof = data.prove(pw).unwrap(); + data.verify(proof) + } + + #[test] + #[ignore] + fn test_ecdsa_circuit_narrow() -> Result<()> { + test_ecdsa_circuit_with_config(CircuitConfig::standard_ecc_config()) + } + + #[test] + #[ignore] + fn test_ecdsa_circuit_wide() -> Result<()> { + test_ecdsa_circuit_with_config(CircuitConfig::wide_ecc_config()) + } +} diff --git a/ecdsa/src/gadgets/glv.rs b/ecdsa/src/gadgets/glv.rs new file mode 100644 index 00000000..c567e5ab --- /dev/null +++ b/ecdsa/src/gadgets/glv.rs @@ -0,0 +1,180 @@ +use std::marker::PhantomData; + +use plonky2::hash::hash_types::RichField; +use plonky2::iop::generator::{GeneratedValues, SimpleGenerator}; +use plonky2::iop::target::{BoolTarget, Target}; +use plonky2::iop::witness::PartitionWitness; +use plonky2::plonk::circuit_builder::CircuitBuilder; +use plonky2_field::extension_field::Extendable; +use plonky2_field::field_types::{Field, PrimeField}; +use plonky2_field::secp256k1_base::Secp256K1Base; +use plonky2_field::secp256k1_scalar::Secp256K1Scalar; + +use crate::curve::glv::{decompose_secp256k1_scalar, GLV_BETA, GLV_S}; +use crate::curve::secp256k1::Secp256K1; +use crate::gadgets::biguint::{buffer_set_biguint_target, witness_get_biguint_target}; +use crate::gadgets::curve::{AffinePointTarget, CircuitBuilderCurve}; +use crate::gadgets::curve_msm::curve_msm_circuit; +use crate::gadgets::nonnative::{CircuitBuilderNonNative, NonNativeTarget}; + +pub trait CircuitBuilderGlv, const D: usize> { + fn secp256k1_glv_beta(&mut self) -> NonNativeTarget; + + fn decompose_secp256k1_scalar( + &mut self, + k: &NonNativeTarget, + ) -> ( + NonNativeTarget, + NonNativeTarget, + BoolTarget, + BoolTarget, + ); + + fn glv_mul( + &mut self, + p: &AffinePointTarget, + k: &NonNativeTarget, + ) -> AffinePointTarget; +} + +impl, const D: usize> CircuitBuilderGlv + for CircuitBuilder +{ + fn secp256k1_glv_beta(&mut self) -> NonNativeTarget { + self.constant_nonnative(GLV_BETA) + } + + fn decompose_secp256k1_scalar( + &mut self, + k: &NonNativeTarget, + ) -> ( + NonNativeTarget, + NonNativeTarget, + BoolTarget, + BoolTarget, + ) { + let k1 = self.add_virtual_nonnative_target_sized::(4); + let k2 = self.add_virtual_nonnative_target_sized::(4); + let k1_neg = self.add_virtual_bool_target(); + let k2_neg = self.add_virtual_bool_target(); + + self.add_simple_generator(GLVDecompositionGenerator:: { + k: k.clone(), + k1: k1.clone(), + k2: k2.clone(), + k1_neg, + k2_neg, + _phantom: PhantomData, + }); + + // Check that `k1_raw + GLV_S * k2_raw == k`. + let k1_raw = self.nonnative_conditional_neg(&k1, k1_neg); + let k2_raw = self.nonnative_conditional_neg(&k2, k2_neg); + let s = self.constant_nonnative(GLV_S); + let mut should_be_k = self.mul_nonnative(&s, &k2_raw); + should_be_k = self.add_nonnative(&should_be_k, &k1_raw); + self.connect_nonnative(&should_be_k, k); + + (k1, k2, k1_neg, k2_neg) + } + + fn glv_mul( + &mut self, + p: &AffinePointTarget, + k: &NonNativeTarget, + ) -> AffinePointTarget { + let (k1, k2, k1_neg, k2_neg) = self.decompose_secp256k1_scalar(k); + + let beta = self.secp256k1_glv_beta(); + let beta_px = self.mul_nonnative(&beta, &p.x); + let sp = AffinePointTarget:: { + x: beta_px, + y: p.y.clone(), + }; + + let p_neg = self.curve_conditional_neg(p, k1_neg); + let sp_neg = self.curve_conditional_neg(&sp, k2_neg); + curve_msm_circuit(self, &p_neg, &sp_neg, &k1, &k2) + } +} + +#[derive(Debug)] +struct GLVDecompositionGenerator, const D: usize> { + k: NonNativeTarget, + k1: NonNativeTarget, + k2: NonNativeTarget, + k1_neg: BoolTarget, + k2_neg: BoolTarget, + _phantom: PhantomData, +} + +impl, const D: usize> SimpleGenerator + for GLVDecompositionGenerator +{ + fn dependencies(&self) -> Vec { + self.k.value.limbs.iter().map(|l| l.0).collect() + } + + fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { + let k = Secp256K1Scalar::from_biguint(witness_get_biguint_target( + witness, + self.k.value.clone(), + )); + + let (k1, k2, k1_neg, k2_neg) = decompose_secp256k1_scalar(k); + + buffer_set_biguint_target(out_buffer, &self.k1.value, &k1.to_canonical_biguint()); + buffer_set_biguint_target(out_buffer, &self.k2.value, &k2.to_canonical_biguint()); + out_buffer.set_bool_target(self.k1_neg, k1_neg); + out_buffer.set_bool_target(self.k2_neg, k2_neg); + } +} + +#[cfg(test)] +mod tests { + use anyhow::Result; + use plonky2::iop::witness::PartialWitness; + use plonky2::plonk::circuit_builder::CircuitBuilder; + use plonky2::plonk::circuit_data::CircuitConfig; + use plonky2::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; + use plonky2_field::field_types::Field; + use plonky2_field::secp256k1_scalar::Secp256K1Scalar; + + use crate::curve::curve_types::{Curve, CurveScalar}; + use crate::curve::glv::glv_mul; + use crate::curve::secp256k1::Secp256K1; + use crate::gadgets::curve::CircuitBuilderCurve; + use crate::gadgets::glv::CircuitBuilderGlv; + use crate::gadgets::nonnative::CircuitBuilderNonNative; + + #[test] + #[ignore] + fn test_glv_gadget() -> Result<()> { + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + + let config = CircuitConfig::standard_ecc_config(); + + let pw = PartialWitness::new(); + let mut builder = CircuitBuilder::::new(config); + + let rando = + (CurveScalar(Secp256K1Scalar::rand()) * Secp256K1::GENERATOR_PROJECTIVE).to_affine(); + let randot = builder.constant_affine_point(rando); + + let scalar = Secp256K1Scalar::rand(); + let scalar_target = builder.constant_nonnative(scalar); + + let rando_glv_scalar = glv_mul(rando.to_projective(), scalar); + let expected = builder.constant_affine_point(rando_glv_scalar.to_affine()); + let actual = builder.glv_mul(&randot, &scalar_target); + builder.connect_affine_point(&expected, &actual); + + dbg!(builder.num_gates()); + let data = builder.build::(); + let proof = data.prove(pw).unwrap(); + + data.verify(proof) + } +} diff --git a/ecdsa/src/gadgets/nonnative.rs b/ecdsa/src/gadgets/nonnative.rs new file mode 100644 index 00000000..76f23f3f --- /dev/null +++ b/ecdsa/src/gadgets/nonnative.rs @@ -0,0 +1,822 @@ +use std::marker::PhantomData; + +use num::{BigUint, Integer, One, Zero}; +use plonky2::gadgets::arithmetic_u32::U32Target; +use plonky2::hash::hash_types::RichField; +use plonky2::iop::generator::{GeneratedValues, SimpleGenerator}; +use plonky2::iop::target::{BoolTarget, Target}; +use plonky2::iop::witness::PartitionWitness; +use plonky2::plonk::circuit_builder::CircuitBuilder; +use plonky2_field::field_types::PrimeField; +use plonky2_field::{extension_field::Extendable, field_types::Field}; +use plonky2_util::ceil_div_usize; + +use crate::gadgets::biguint::{ + buffer_set_biguint_target, witness_get_biguint_target, BigUintTarget, CircuitBuilderBiguint, +}; + +#[derive(Clone, Debug)] +pub struct NonNativeTarget { + pub(crate) value: BigUintTarget, + pub(crate) _phantom: PhantomData, +} + +pub trait CircuitBuilderNonNative, const D: usize> { + fn num_nonnative_limbs() -> usize { + ceil_div_usize(FF::BITS, 32) + } + + fn biguint_to_nonnative(&mut self, x: &BigUintTarget) -> NonNativeTarget; + + fn nonnative_to_canonical_biguint( + &mut self, + x: &NonNativeTarget, + ) -> BigUintTarget; + + fn constant_nonnative(&mut self, x: FF) -> NonNativeTarget; + + fn zero_nonnative(&mut self) -> NonNativeTarget; + + // Assert that two NonNativeTarget's, both assumed to be in reduced form, are equal. + fn connect_nonnative( + &mut self, + lhs: &NonNativeTarget, + rhs: &NonNativeTarget, + ); + + fn add_virtual_nonnative_target(&mut self) -> NonNativeTarget; + + fn add_virtual_nonnative_target_sized( + &mut self, + num_limbs: usize, + ) -> NonNativeTarget; + + fn add_nonnative( + &mut self, + a: &NonNativeTarget, + b: &NonNativeTarget, + ) -> NonNativeTarget; + + fn mul_nonnative_by_bool( + &mut self, + a: &NonNativeTarget, + b: BoolTarget, + ) -> NonNativeTarget; + + fn if_nonnative( + &mut self, + b: BoolTarget, + x: &NonNativeTarget, + y: &NonNativeTarget, + ) -> NonNativeTarget; + + fn add_many_nonnative( + &mut self, + to_add: &[NonNativeTarget], + ) -> NonNativeTarget; + + // Subtract two `NonNativeTarget`s. + fn sub_nonnative( + &mut self, + a: &NonNativeTarget, + b: &NonNativeTarget, + ) -> NonNativeTarget; + + fn mul_nonnative( + &mut self, + a: &NonNativeTarget, + b: &NonNativeTarget, + ) -> NonNativeTarget; + + fn mul_many_nonnative( + &mut self, + to_mul: &[NonNativeTarget], + ) -> NonNativeTarget; + + fn neg_nonnative(&mut self, x: &NonNativeTarget) -> NonNativeTarget; + + fn inv_nonnative(&mut self, x: &NonNativeTarget) -> NonNativeTarget; + + /// Returns `x % |FF|` as a `NonNativeTarget`. + fn reduce(&mut self, x: &BigUintTarget) -> NonNativeTarget; + + fn reduce_nonnative(&mut self, x: &NonNativeTarget) -> NonNativeTarget; + + fn bool_to_nonnative(&mut self, b: &BoolTarget) -> NonNativeTarget; + + // Split a nonnative field element to bits. + fn split_nonnative_to_bits(&mut self, x: &NonNativeTarget) -> Vec; + + fn nonnative_conditional_neg( + &mut self, + x: &NonNativeTarget, + b: BoolTarget, + ) -> NonNativeTarget; +} + +impl, const D: usize> CircuitBuilderNonNative + for CircuitBuilder +{ + fn num_nonnative_limbs() -> usize { + ceil_div_usize(FF::BITS, 32) + } + + fn biguint_to_nonnative(&mut self, x: &BigUintTarget) -> NonNativeTarget { + NonNativeTarget { + value: x.clone(), + _phantom: PhantomData, + } + } + + fn nonnative_to_canonical_biguint( + &mut self, + x: &NonNativeTarget, + ) -> BigUintTarget { + x.value.clone() + } + + fn constant_nonnative(&mut self, x: FF) -> NonNativeTarget { + let x_biguint = self.constant_biguint(&x.to_canonical_biguint()); + self.biguint_to_nonnative(&x_biguint) + } + + fn zero_nonnative(&mut self) -> NonNativeTarget { + self.constant_nonnative(FF::ZERO) + } + + // Assert that two NonNativeTarget's, both assumed to be in reduced form, are equal. + fn connect_nonnative( + &mut self, + lhs: &NonNativeTarget, + rhs: &NonNativeTarget, + ) { + self.connect_biguint(&lhs.value, &rhs.value); + } + + fn add_virtual_nonnative_target(&mut self) -> NonNativeTarget { + let num_limbs = Self::num_nonnative_limbs::(); + let value = self.add_virtual_biguint_target(num_limbs); + + NonNativeTarget { + value, + _phantom: PhantomData, + } + } + + fn add_virtual_nonnative_target_sized( + &mut self, + num_limbs: usize, + ) -> NonNativeTarget { + let value = self.add_virtual_biguint_target(num_limbs); + + NonNativeTarget { + value, + _phantom: PhantomData, + } + } + + fn add_nonnative( + &mut self, + a: &NonNativeTarget, + b: &NonNativeTarget, + ) -> NonNativeTarget { + let sum = self.add_virtual_nonnative_target::(); + let overflow = self.add_virtual_bool_target(); + + self.add_simple_generator(NonNativeAdditionGenerator:: { + a: a.clone(), + b: b.clone(), + sum: sum.clone(), + overflow, + _phantom: PhantomData, + }); + + let sum_expected = self.add_biguint(&a.value, &b.value); + + let modulus = self.constant_biguint(&FF::order()); + let mod_times_overflow = self.mul_biguint_by_bool(&modulus, overflow); + let sum_actual = self.add_biguint(&sum.value, &mod_times_overflow); + self.connect_biguint(&sum_expected, &sum_actual); + + // Range-check result. + // TODO: can potentially leave unreduced until necessary (e.g. when connecting values). + let cmp = self.cmp_biguint(&sum.value, &modulus); + let one = self.one(); + self.connect(cmp.target, one); + + sum + } + + fn mul_nonnative_by_bool( + &mut self, + a: &NonNativeTarget, + b: BoolTarget, + ) -> NonNativeTarget { + NonNativeTarget { + value: self.mul_biguint_by_bool(&a.value, b), + _phantom: PhantomData, + } + } + + fn if_nonnative( + &mut self, + b: BoolTarget, + x: &NonNativeTarget, + y: &NonNativeTarget, + ) -> NonNativeTarget { + let not_b = self.not(b); + let maybe_x = self.mul_nonnative_by_bool(x, b); + let maybe_y = self.mul_nonnative_by_bool(y, not_b); + self.add_nonnative(&maybe_x, &maybe_y) + } + + fn add_many_nonnative( + &mut self, + to_add: &[NonNativeTarget], + ) -> NonNativeTarget { + if to_add.len() == 1 { + return to_add[0].clone(); + } + + let sum = self.add_virtual_nonnative_target::(); + let overflow = self.add_virtual_u32_target(); + let summands = to_add.to_vec(); + + self.add_simple_generator(NonNativeMultipleAddsGenerator:: { + summands: summands.clone(), + sum: sum.clone(), + overflow, + _phantom: PhantomData, + }); + + self.range_check_u32(sum.value.limbs.clone()); + self.range_check_u32(vec![overflow]); + + let sum_expected = summands + .iter() + .fold(self.zero_biguint(), |a, b| self.add_biguint(&a, &b.value)); + + let modulus = self.constant_biguint(&FF::order()); + let overflow_biguint = BigUintTarget { + limbs: vec![overflow], + }; + let mod_times_overflow = self.mul_biguint(&modulus, &overflow_biguint); + let sum_actual = self.add_biguint(&sum.value, &mod_times_overflow); + self.connect_biguint(&sum_expected, &sum_actual); + + // Range-check result. + // TODO: can potentially leave unreduced until necessary (e.g. when connecting values). + let cmp = self.cmp_biguint(&sum.value, &modulus); + let one = self.one(); + self.connect(cmp.target, one); + + sum + } + + // Subtract two `NonNativeTarget`s. + fn sub_nonnative( + &mut self, + a: &NonNativeTarget, + b: &NonNativeTarget, + ) -> NonNativeTarget { + let diff = self.add_virtual_nonnative_target::(); + let overflow = self.add_virtual_bool_target(); + + self.add_simple_generator(NonNativeSubtractionGenerator:: { + a: a.clone(), + b: b.clone(), + diff: diff.clone(), + overflow, + _phantom: PhantomData, + }); + + self.range_check_u32(diff.value.limbs.clone()); + self.assert_bool(overflow); + + let diff_plus_b = self.add_biguint(&diff.value, &b.value); + let modulus = self.constant_biguint(&FF::order()); + let mod_times_overflow = self.mul_biguint_by_bool(&modulus, overflow); + let diff_plus_b_reduced = self.sub_biguint(&diff_plus_b, &mod_times_overflow); + self.connect_biguint(&a.value, &diff_plus_b_reduced); + + diff + } + + fn mul_nonnative( + &mut self, + a: &NonNativeTarget, + b: &NonNativeTarget, + ) -> NonNativeTarget { + let prod = self.add_virtual_nonnative_target::(); + let modulus = self.constant_biguint(&FF::order()); + let overflow = self.add_virtual_biguint_target( + a.value.num_limbs() + b.value.num_limbs() - modulus.num_limbs(), + ); + + self.add_simple_generator(NonNativeMultiplicationGenerator:: { + a: a.clone(), + b: b.clone(), + prod: prod.clone(), + overflow: overflow.clone(), + _phantom: PhantomData, + }); + + self.range_check_u32(prod.value.limbs.clone()); + self.range_check_u32(overflow.limbs.clone()); + + let prod_expected = self.mul_biguint(&a.value, &b.value); + + let mod_times_overflow = self.mul_biguint(&modulus, &overflow); + let prod_actual = self.add_biguint(&prod.value, &mod_times_overflow); + self.connect_biguint(&prod_expected, &prod_actual); + + prod + } + + fn mul_many_nonnative( + &mut self, + to_mul: &[NonNativeTarget], + ) -> NonNativeTarget { + if to_mul.len() == 1 { + return to_mul[0].clone(); + } + + let mut accumulator = self.mul_nonnative(&to_mul[0], &to_mul[1]); + for t in to_mul.iter().skip(2) { + accumulator = self.mul_nonnative(&accumulator, t); + } + accumulator + } + + fn neg_nonnative(&mut self, x: &NonNativeTarget) -> NonNativeTarget { + let zero_target = self.constant_biguint(&BigUint::zero()); + let zero_ff = self.biguint_to_nonnative(&zero_target); + + self.sub_nonnative(&zero_ff, x) + } + + fn inv_nonnative(&mut self, x: &NonNativeTarget) -> NonNativeTarget { + let num_limbs = x.value.num_limbs(); + let inv_biguint = self.add_virtual_biguint_target(num_limbs); + let div = self.add_virtual_biguint_target(num_limbs); + + self.add_simple_generator(NonNativeInverseGenerator:: { + x: x.clone(), + inv: inv_biguint.clone(), + div: div.clone(), + _phantom: PhantomData, + }); + + let product = self.mul_biguint(&x.value, &inv_biguint); + + let modulus = self.constant_biguint(&FF::order()); + let mod_times_div = self.mul_biguint(&modulus, &div); + let one = self.constant_biguint(&BigUint::one()); + let expected_product = self.add_biguint(&mod_times_div, &one); + self.connect_biguint(&product, &expected_product); + + NonNativeTarget:: { + value: inv_biguint, + _phantom: PhantomData, + } + } + + /// Returns `x % |FF|` as a `NonNativeTarget`. + fn reduce(&mut self, x: &BigUintTarget) -> NonNativeTarget { + let modulus = FF::order(); + let order_target = self.constant_biguint(&modulus); + let value = self.rem_biguint(x, &order_target); + + NonNativeTarget { + value, + _phantom: PhantomData, + } + } + + fn reduce_nonnative(&mut self, x: &NonNativeTarget) -> NonNativeTarget { + let x_biguint = self.nonnative_to_canonical_biguint(x); + self.reduce(&x_biguint) + } + + fn bool_to_nonnative(&mut self, b: &BoolTarget) -> NonNativeTarget { + let limbs = vec![U32Target(b.target)]; + let value = BigUintTarget { limbs }; + + NonNativeTarget { + value, + _phantom: PhantomData, + } + } + + // Split a nonnative field element to bits. + fn split_nonnative_to_bits(&mut self, x: &NonNativeTarget) -> Vec { + let num_limbs = x.value.num_limbs(); + let mut result = Vec::with_capacity(num_limbs * 32); + + for i in 0..num_limbs { + let limb = x.value.get_limb(i); + let bit_targets = self.split_le_base::<2>(limb.0, 32); + let mut bits: Vec<_> = bit_targets + .iter() + .map(|&t| BoolTarget::new_unsafe(t)) + .collect(); + + result.append(&mut bits); + } + + result + } + + fn nonnative_conditional_neg( + &mut self, + x: &NonNativeTarget, + b: BoolTarget, + ) -> NonNativeTarget { + let not_b = self.not(b); + let neg = self.neg_nonnative(x); + let x_if_true = self.mul_nonnative_by_bool(&neg, b); + let x_if_false = self.mul_nonnative_by_bool(x, not_b); + + self.add_nonnative(&x_if_true, &x_if_false) + } +} + +#[derive(Debug)] +struct NonNativeAdditionGenerator, const D: usize, FF: PrimeField> { + a: NonNativeTarget, + b: NonNativeTarget, + sum: NonNativeTarget, + overflow: BoolTarget, + _phantom: PhantomData, +} + +impl, const D: usize, FF: PrimeField> SimpleGenerator + for NonNativeAdditionGenerator +{ + fn dependencies(&self) -> Vec { + self.a + .value + .limbs + .iter() + .cloned() + .chain(self.b.value.limbs.clone()) + .map(|l| l.0) + .collect() + } + + fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { + let a = FF::from_biguint(witness_get_biguint_target(witness, self.a.value.clone())); + let b = FF::from_biguint(witness_get_biguint_target(witness, self.b.value.clone())); + let a_biguint = a.to_canonical_biguint(); + let b_biguint = b.to_canonical_biguint(); + let sum_biguint = a_biguint + b_biguint; + let modulus = FF::order(); + let (overflow, sum_reduced) = if sum_biguint > modulus { + (true, sum_biguint - modulus) + } else { + (false, sum_biguint) + }; + + buffer_set_biguint_target(out_buffer, &self.sum.value, &sum_reduced); + out_buffer.set_bool_target(self.overflow, overflow); + } +} + +#[derive(Debug)] +struct NonNativeMultipleAddsGenerator, const D: usize, FF: PrimeField> +{ + summands: Vec>, + sum: NonNativeTarget, + overflow: U32Target, + _phantom: PhantomData, +} + +impl, const D: usize, FF: PrimeField> SimpleGenerator + for NonNativeMultipleAddsGenerator +{ + fn dependencies(&self) -> Vec { + self.summands + .iter() + .flat_map(|summand| summand.value.limbs.iter().map(|limb| limb.0)) + .collect() + } + + fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { + let summands: Vec<_> = self + .summands + .iter() + .map(|summand| { + FF::from_biguint(witness_get_biguint_target(witness, summand.value.clone())) + }) + .collect(); + let summand_biguints: Vec<_> = summands + .iter() + .map(|summand| summand.to_canonical_biguint()) + .collect(); + + let sum_biguint = summand_biguints + .iter() + .fold(BigUint::zero(), |a, b| a + b.clone()); + + let modulus = FF::order(); + let (overflow_biguint, sum_reduced) = sum_biguint.div_rem(&modulus); + let overflow = overflow_biguint.to_u64_digits()[0] as u32; + + buffer_set_biguint_target(out_buffer, &self.sum.value, &sum_reduced); + out_buffer.set_u32_target(self.overflow, overflow); + } +} + +#[derive(Debug)] +struct NonNativeSubtractionGenerator, const D: usize, FF: Field> { + a: NonNativeTarget, + b: NonNativeTarget, + diff: NonNativeTarget, + overflow: BoolTarget, + _phantom: PhantomData, +} + +impl, const D: usize, FF: PrimeField> SimpleGenerator + for NonNativeSubtractionGenerator +{ + fn dependencies(&self) -> Vec { + self.a + .value + .limbs + .iter() + .cloned() + .chain(self.b.value.limbs.clone()) + .map(|l| l.0) + .collect() + } + + fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { + let a = FF::from_biguint(witness_get_biguint_target(witness, self.a.value.clone())); + let b = FF::from_biguint(witness_get_biguint_target(witness, self.b.value.clone())); + let a_biguint = a.to_canonical_biguint(); + let b_biguint = b.to_canonical_biguint(); + + let modulus = FF::order(); + let (diff_biguint, overflow) = if a_biguint >= b_biguint { + (a_biguint - b_biguint, false) + } else { + (modulus + a_biguint - b_biguint, true) + }; + + buffer_set_biguint_target(out_buffer, &self.diff.value, &diff_biguint); + out_buffer.set_bool_target(self.overflow, overflow); + } +} + +#[derive(Debug)] +struct NonNativeMultiplicationGenerator, const D: usize, FF: Field> { + a: NonNativeTarget, + b: NonNativeTarget, + prod: NonNativeTarget, + overflow: BigUintTarget, + _phantom: PhantomData, +} + +impl, const D: usize, FF: PrimeField> SimpleGenerator + for NonNativeMultiplicationGenerator +{ + fn dependencies(&self) -> Vec { + self.a + .value + .limbs + .iter() + .cloned() + .chain(self.b.value.limbs.clone()) + .map(|l| l.0) + .collect() + } + + fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { + let a = FF::from_biguint(witness_get_biguint_target(witness, self.a.value.clone())); + let b = FF::from_biguint(witness_get_biguint_target(witness, self.b.value.clone())); + let a_biguint = a.to_canonical_biguint(); + let b_biguint = b.to_canonical_biguint(); + + let prod_biguint = a_biguint * b_biguint; + + let modulus = FF::order(); + let (overflow_biguint, prod_reduced) = prod_biguint.div_rem(&modulus); + + buffer_set_biguint_target(out_buffer, &self.prod.value, &prod_reduced); + buffer_set_biguint_target(out_buffer, &self.overflow, &overflow_biguint); + } +} + +#[derive(Debug)] +struct NonNativeInverseGenerator, const D: usize, FF: PrimeField> { + x: NonNativeTarget, + inv: BigUintTarget, + div: BigUintTarget, + _phantom: PhantomData, +} + +impl, const D: usize, FF: PrimeField> SimpleGenerator + for NonNativeInverseGenerator +{ + fn dependencies(&self) -> Vec { + self.x.value.limbs.iter().map(|&l| l.0).collect() + } + + fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { + let x = FF::from_biguint(witness_get_biguint_target(witness, self.x.value.clone())); + let inv = x.inverse(); + + let x_biguint = x.to_canonical_biguint(); + let inv_biguint = inv.to_canonical_biguint(); + let prod = x_biguint * &inv_biguint; + let modulus = FF::order(); + let (div, _rem) = prod.div_rem(&modulus); + + buffer_set_biguint_target(out_buffer, &self.div, &div); + buffer_set_biguint_target(out_buffer, &self.inv, &inv_biguint); + } +} + +#[cfg(test)] +mod tests { + use anyhow::Result; + use plonky2::iop::witness::PartialWitness; + use plonky2::plonk::circuit_builder::CircuitBuilder; + use plonky2::plonk::circuit_data::CircuitConfig; + use plonky2::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; + use plonky2_field::field_types::{Field, PrimeField}; + use plonky2_field::secp256k1_base::Secp256K1Base; + + use crate::gadgets::nonnative::CircuitBuilderNonNative; + + #[test] + fn test_nonnative_add() -> Result<()> { + type FF = Secp256K1Base; + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + + let x_ff = FF::rand(); + let y_ff = FF::rand(); + let sum_ff = x_ff + y_ff; + + let config = CircuitConfig::standard_ecc_config(); + let pw = PartialWitness::new(); + let mut builder = CircuitBuilder::::new(config); + + let x = builder.constant_nonnative(x_ff); + let y = builder.constant_nonnative(y_ff); + let sum = builder.add_nonnative(&x, &y); + + let sum_expected = builder.constant_nonnative(sum_ff); + builder.connect_nonnative(&sum, &sum_expected); + + let data = builder.build::(); + let proof = data.prove(pw).unwrap(); + data.verify(proof) + } + + #[test] + fn test_nonnative_many_adds() -> Result<()> { + type FF = Secp256K1Base; + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + + let a_ff = FF::rand(); + let b_ff = FF::rand(); + let c_ff = FF::rand(); + let d_ff = FF::rand(); + let e_ff = FF::rand(); + let f_ff = FF::rand(); + let g_ff = FF::rand(); + let h_ff = FF::rand(); + let sum_ff = a_ff + b_ff + c_ff + d_ff + e_ff + f_ff + g_ff + h_ff; + + let config = CircuitConfig::standard_ecc_config(); + let pw = PartialWitness::new(); + let mut builder = CircuitBuilder::::new(config); + + let a = builder.constant_nonnative(a_ff); + let b = builder.constant_nonnative(b_ff); + let c = builder.constant_nonnative(c_ff); + let d = builder.constant_nonnative(d_ff); + let e = builder.constant_nonnative(e_ff); + let f = builder.constant_nonnative(f_ff); + let g = builder.constant_nonnative(g_ff); + let h = builder.constant_nonnative(h_ff); + let all = [a, b, c, d, e, f, g, h]; + let sum = builder.add_many_nonnative(&all); + + let sum_expected = builder.constant_nonnative(sum_ff); + builder.connect_nonnative(&sum, &sum_expected); + + let data = builder.build::(); + let proof = data.prove(pw).unwrap(); + data.verify(proof) + } + + #[test] + fn test_nonnative_sub() -> Result<()> { + type FF = Secp256K1Base; + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + + let x_ff = FF::rand(); + let mut y_ff = FF::rand(); + while y_ff.to_canonical_biguint() > x_ff.to_canonical_biguint() { + y_ff = FF::rand(); + } + let diff_ff = x_ff - y_ff; + + let config = CircuitConfig::standard_ecc_config(); + let pw = PartialWitness::new(); + let mut builder = CircuitBuilder::::new(config); + + let x = builder.constant_nonnative(x_ff); + let y = builder.constant_nonnative(y_ff); + let diff = builder.sub_nonnative(&x, &y); + + let diff_expected = builder.constant_nonnative(diff_ff); + builder.connect_nonnative(&diff, &diff_expected); + + let data = builder.build::(); + let proof = data.prove(pw).unwrap(); + data.verify(proof) + } + + #[test] + fn test_nonnative_mul() -> Result<()> { + type FF = Secp256K1Base; + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + let x_ff = FF::rand(); + let y_ff = FF::rand(); + let product_ff = x_ff * y_ff; + + let config = CircuitConfig::standard_ecc_config(); + let pw = PartialWitness::new(); + let mut builder = CircuitBuilder::::new(config); + + let x = builder.constant_nonnative(x_ff); + let y = builder.constant_nonnative(y_ff); + let product = builder.mul_nonnative(&x, &y); + + let product_expected = builder.constant_nonnative(product_ff); + builder.connect_nonnative(&product, &product_expected); + + let data = builder.build::(); + let proof = data.prove(pw).unwrap(); + data.verify(proof) + } + + #[test] + fn test_nonnative_neg() -> Result<()> { + type FF = Secp256K1Base; + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + let x_ff = FF::rand(); + let neg_x_ff = -x_ff; + + let config = CircuitConfig::standard_ecc_config(); + let pw = PartialWitness::new(); + let mut builder = CircuitBuilder::::new(config); + + let x = builder.constant_nonnative(x_ff); + let neg_x = builder.neg_nonnative(&x); + + let neg_x_expected = builder.constant_nonnative(neg_x_ff); + builder.connect_nonnative(&neg_x, &neg_x_expected); + + let data = builder.build::(); + let proof = data.prove(pw).unwrap(); + data.verify(proof) + } + + #[test] + fn test_nonnative_inv() -> Result<()> { + type FF = Secp256K1Base; + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + let x_ff = FF::rand(); + let inv_x_ff = x_ff.inverse(); + + let config = CircuitConfig::standard_ecc_config(); + let pw = PartialWitness::new(); + let mut builder = CircuitBuilder::::new(config); + + let x = builder.constant_nonnative(x_ff); + let inv_x = builder.inv_nonnative(&x); + + let inv_x_expected = builder.constant_nonnative(inv_x_ff); + builder.connect_nonnative(&inv_x, &inv_x_expected); + + let data = builder.build::(); + let proof = data.prove(pw).unwrap(); + data.verify(proof) + } +} diff --git a/ecdsa/src/gadgets/split_nonnative.rs b/ecdsa/src/gadgets/split_nonnative.rs new file mode 100644 index 00000000..e79438f8 --- /dev/null +++ b/ecdsa/src/gadgets/split_nonnative.rs @@ -0,0 +1,131 @@ +use std::marker::PhantomData; + +use itertools::Itertools; +use plonky2::gadgets::arithmetic_u32::U32Target; +use plonky2::hash::hash_types::RichField; +use plonky2::iop::target::Target; +use plonky2::plonk::circuit_builder::CircuitBuilder; +use plonky2_field::extension_field::Extendable; +use plonky2_field::field_types::Field; + +use crate::gadgets::biguint::BigUintTarget; +use crate::gadgets::nonnative::NonNativeTarget; + +pub trait CircuitBuilderSplit, const D: usize> { + fn split_u32_to_4_bit_limbs(&mut self, val: U32Target) -> Vec; + + fn split_nonnative_to_4_bit_limbs( + &mut self, + val: &NonNativeTarget, + ) -> Vec; + + fn split_nonnative_to_2_bit_limbs( + &mut self, + val: &NonNativeTarget, + ) -> Vec; + + // Note: assumes its inputs are 4-bit limbs, and does not range-check. + fn recombine_nonnative_4_bit_limbs( + &mut self, + limbs: Vec, + ) -> NonNativeTarget; +} + +impl, const D: usize> CircuitBuilderSplit + for CircuitBuilder +{ + fn split_u32_to_4_bit_limbs(&mut self, val: U32Target) -> Vec { + let two_bit_limbs = self.split_le_base::<4>(val.0, 16); + let four = self.constant(F::from_canonical_usize(4)); + let combined_limbs = two_bit_limbs + .iter() + .tuples() + .map(|(&a, &b)| self.mul_add(b, four, a)) + .collect(); + + combined_limbs + } + + fn split_nonnative_to_4_bit_limbs( + &mut self, + val: &NonNativeTarget, + ) -> Vec { + val.value + .limbs + .iter() + .flat_map(|&l| self.split_u32_to_4_bit_limbs(l)) + .collect() + } + + fn split_nonnative_to_2_bit_limbs( + &mut self, + val: &NonNativeTarget, + ) -> Vec { + val.value + .limbs + .iter() + .flat_map(|&l| self.split_le_base::<4>(l.0, 16)) + .collect() + } + + // Note: assumes its inputs are 4-bit limbs, and does not range-check. + fn recombine_nonnative_4_bit_limbs( + &mut self, + limbs: Vec, + ) -> NonNativeTarget { + let base = self.constant_u32(1 << 4); + let u32_limbs = limbs + .chunks(8) + .map(|chunk| { + let mut combined_chunk = self.zero_u32(); + for i in (0..8).rev() { + let (low, _high) = self.mul_add_u32(combined_chunk, base, U32Target(chunk[i])); + combined_chunk = low; + } + combined_chunk + }) + .collect(); + + NonNativeTarget { + value: BigUintTarget { limbs: u32_limbs }, + _phantom: PhantomData, + } + } +} + +#[cfg(test)] +mod tests { + use anyhow::Result; + use plonky2::iop::witness::PartialWitness; + use plonky2::plonk::circuit_builder::CircuitBuilder; + use plonky2::plonk::circuit_data::CircuitConfig; + use plonky2::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; + use plonky2_field::field_types::Field; + use plonky2_field::secp256k1_scalar::Secp256K1Scalar; + + use crate::gadgets::nonnative::{CircuitBuilderNonNative, NonNativeTarget}; + use crate::gadgets::split_nonnative::CircuitBuilderSplit; + + #[test] + fn test_split_nonnative() -> Result<()> { + type FF = Secp256K1Scalar; + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + + let config = CircuitConfig::standard_ecc_config(); + let pw = PartialWitness::new(); + let mut builder = CircuitBuilder::::new(config); + + let x = FF::rand(); + let x_target = builder.constant_nonnative(x); + let split = builder.split_nonnative_to_4_bit_limbs(&x_target); + let combined: NonNativeTarget = + builder.recombine_nonnative_4_bit_limbs(split); + builder.connect_nonnative(&x_target, &combined); + + let data = builder.build::(); + let proof = data.prove(pw).unwrap(); + data.verify(proof) + } +} diff --git a/ecdsa/src/lib.rs b/ecdsa/src/lib.rs new file mode 100644 index 00000000..1de32647 --- /dev/null +++ b/ecdsa/src/lib.rs @@ -0,0 +1,4 @@ +#![allow(clippy::needless_range_loop)] + +pub mod curve; +pub mod gadgets; From 786c1eafcfc714574fed2db4a226a12d7ae60d25 Mon Sep 17 00:00:00 2001 From: wborgeaud Date: Wed, 16 Mar 2022 05:44:02 +0100 Subject: [PATCH 54/56] Minor --- ecdsa/Cargo.toml | 4 +--- ecdsa/src/gadgets/biguint.rs | 9 +++------ ecdsa/src/gadgets/curve.rs | 3 +-- 3 files changed, 5 insertions(+), 11 deletions(-) diff --git a/ecdsa/Cargo.toml b/ecdsa/Cargo.toml index 59ad9ee3..a7e1024b 100644 --- a/ecdsa/Cargo.toml +++ b/ecdsa/Cargo.toml @@ -14,6 +14,4 @@ itertools = "0.10.0" rayon = "1.5.1" serde = { version = "1.0", features = ["derive"] } anyhow = "1.0.40" -rand = "0.8.4" -#env_logger = "0.9.0" -#log = "0.4.14" +rand = "0.8.4" \ No newline at end of file diff --git a/ecdsa/src/gadgets/biguint.rs b/ecdsa/src/gadgets/biguint.rs index 5c077747..0c08814c 100644 --- a/ecdsa/src/gadgets/biguint.rs +++ b/ecdsa/src/gadgets/biguint.rs @@ -42,17 +42,17 @@ pub trait CircuitBuilderBiguint, const D: usize> { fn add_virtual_biguint_target(&mut self, num_limbs: usize) -> BigUintTarget; - // Add two `BigUintTarget`s. + /// Add two `BigUintTarget`s. fn add_biguint(&mut self, a: &BigUintTarget, b: &BigUintTarget) -> BigUintTarget; - // Subtract two `BigUintTarget`s. We assume that the first is larger than the second. + /// Subtract two `BigUintTarget`s. We assume that the first is larger than the second. fn sub_biguint(&mut self, a: &BigUintTarget, b: &BigUintTarget) -> BigUintTarget; fn mul_biguint(&mut self, a: &BigUintTarget, b: &BigUintTarget) -> BigUintTarget; fn mul_biguint_by_bool(&mut self, a: &BigUintTarget, b: BoolTarget) -> BigUintTarget; - // Returns x * y + z. This is no more efficient than mul-then-add; it's purely for convenience (only need to call one CircuitBuilder function). + /// Returns x * y + z. This is no more efficient than mul-then-add; it's purely for convenience (only need to call one CircuitBuilder function). fn mul_add_biguint( &mut self, x: &BigUintTarget, @@ -133,7 +133,6 @@ impl, const D: usize> CircuitBuilderBiguint BigUintTarget { limbs } } - // Add two `BigUintTarget`s. fn add_biguint(&mut self, a: &BigUintTarget, b: &BigUintTarget) -> BigUintTarget { let num_limbs = a.num_limbs().max(b.num_limbs()); @@ -158,7 +157,6 @@ impl, const D: usize> CircuitBuilderBiguint } } - // Subtract two `BigUintTarget`s. We assume that the first is larger than the second. fn sub_biguint(&mut self, a: &BigUintTarget, b: &BigUintTarget) -> BigUintTarget { let (a, b) = self.pad_biguints(a, b); let num_limbs = a.limbs.len(); @@ -216,7 +214,6 @@ impl, const D: usize> CircuitBuilderBiguint } } - // Returns x * y + z. This is no more efficient than mul-then-add; it's purely for convenience (only need to call one CircuitBuilder function). fn mul_add_biguint( &mut self, x: &BigUintTarget, diff --git a/ecdsa/src/gadgets/curve.rs b/ecdsa/src/gadgets/curve.rs index e8a277c4..e511973b 100644 --- a/ecdsa/src/gadgets/curve.rs +++ b/ecdsa/src/gadgets/curve.rs @@ -50,7 +50,7 @@ pub trait CircuitBuilderCurve, const D: usize> { n: usize, ) -> AffinePointTarget; - // Add two points, which are assumed to be non-equal. + /// Add two points, which are assumed to be non-equal. fn curve_add( &mut self, p1: &AffinePointTarget, @@ -169,7 +169,6 @@ impl, const D: usize> CircuitBuilderCurve result } - // Add two points, which are assumed to be non-equal. fn curve_add( &mut self, p1: &AffinePointTarget, From ddd5192489e880428136472423e5e13e6036a70d Mon Sep 17 00:00:00 2001 From: wborgeaud Date: Fri, 18 Mar 2022 08:04:40 +0100 Subject: [PATCH 55/56] Move `secret_to_public` to a `ECDSASecretKey` method --- ecdsa/src/curve/ecdsa.rs | 21 ++++++++++++--------- plonky2/src/gadgets/ecdsa.rs | 0 2 files changed, 12 insertions(+), 9 deletions(-) delete mode 100644 plonky2/src/gadgets/ecdsa.rs diff --git a/ecdsa/src/curve/ecdsa.rs b/ecdsa/src/curve/ecdsa.rs index 52262830..bb4ebe4a 100644 --- a/ecdsa/src/curve/ecdsa.rs +++ b/ecdsa/src/curve/ecdsa.rs @@ -1,8 +1,8 @@ +use plonky2_field::field_types::Field; use serde::{Deserialize, Serialize}; use crate::curve::curve_msm::msm_parallel; use crate::curve::curve_types::{base_to_scalar, AffinePoint, Curve, CurveScalar}; -use crate::field::field_types::Field; #[derive(Copy, Clone, Debug, Deserialize, Eq, Hash, PartialEq, Serialize)] pub struct ECDSASignature { @@ -13,13 +13,15 @@ pub struct ECDSASignature { #[derive(Copy, Clone, Debug, Deserialize, Eq, Hash, PartialEq, Serialize)] pub struct ECDSASecretKey(pub C::ScalarField); +impl ECDSASecretKey { + pub fn to_public(&self) -> ECDSAPublicKey { + ECDSAPublicKey((CurveScalar(self.0) * C::GENERATOR_PROJECTIVE).to_affine()) + } +} + #[derive(Copy, Clone, Debug, Deserialize, Eq, Hash, PartialEq, Serialize)] pub struct ECDSAPublicKey(pub AffinePoint); -pub fn secret_to_public(sk: ECDSASecretKey) -> ECDSAPublicKey { - ECDSAPublicKey((CurveScalar(sk.0) * C::GENERATOR_PROJECTIVE).to_affine()) -} - pub fn sign_message(msg: C::ScalarField, sk: ECDSASecretKey) -> ECDSASignature { let (k, rr) = { let mut k = C::ScalarField::rand(); @@ -61,10 +63,11 @@ pub fn verify_message( #[cfg(test)] mod tests { - use crate::curve::ecdsa::{secret_to_public, sign_message, verify_message, ECDSASecretKey}; + use plonky2_field::field_types::Field; + use plonky2_field::secp256k1_scalar::Secp256K1Scalar; + + use crate::curve::ecdsa::{sign_message, verify_message, ECDSASecretKey}; use crate::curve::secp256k1::Secp256K1; - use crate::field::field_types::Field; - use crate::field::secp256k1_scalar::Secp256K1Scalar; #[test] fn test_ecdsa_native() { @@ -72,7 +75,7 @@ mod tests { let msg = Secp256K1Scalar::rand(); let sk = ECDSASecretKey::(Secp256K1Scalar::rand()); - let pk = secret_to_public(sk); + let pk = sk.to_public(); let sig = sign_message(msg, sk); let result = verify_message(msg, sig, pk); diff --git a/plonky2/src/gadgets/ecdsa.rs b/plonky2/src/gadgets/ecdsa.rs deleted file mode 100644 index e69de29b..00000000 From 9b6582557256c8fdeaf1fded2804d792584933e8 Mon Sep 17 00:00:00 2001 From: wborgeaud Date: Mon, 28 Mar 2022 14:24:29 +0200 Subject: [PATCH 56/56] Comments --- ecdsa/src/curve/glv.rs | 12 ++++++++---- ecdsa/src/gadgets/biguint.rs | 3 +-- plonky2/src/plonk/circuit_builder.rs | 3 +-- 3 files changed, 10 insertions(+), 8 deletions(-) diff --git a/ecdsa/src/curve/glv.rs b/ecdsa/src/curve/glv.rs index aeeb463e..8859d904 100644 --- a/ecdsa/src/curve/glv.rs +++ b/ecdsa/src/curve/glv.rs @@ -31,6 +31,10 @@ const A2: Secp256K1Scalar = Secp256K1Scalar([6323353552219852760, 14980988506747 const B2: Secp256K1Scalar = Secp256K1Scalar([16747920425669159701, 3496713202691238861, 0, 0]); +/// Algorithm 15.41 in Handbook of Elliptic and Hyperelliptic Curve Cryptography. +/// Decompose a scalar `k` into two small scalars `k1, k2` with `|k1|, |k2| < √p` that satisfy +/// `k1 + s * k2 = k`. +/// Returns `(|k1|, |k2|, k1 < 0, k2 < 0)`. pub fn decompose_secp256k1_scalar( k: Secp256K1Scalar, ) -> (Secp256K1Scalar, Secp256K1Scalar, bool, bool) { @@ -71,12 +75,12 @@ pub fn decompose_secp256k1_scalar( (k1, k2, k1_neg, k2_neg) } +/// See Section 15.2.1 in Handbook of Elliptic and Hyperelliptic Curve Cryptography. +/// GLV scalar multiplication `k * P = k1 * P + k2 * psi(P)`, where `k = k1 + s * k2` is the +/// decomposition computed in `decompose_secp256k1_scalar(k)` and `psi` is the Secp256k1 +/// endomorphism `psi: (x, y) |-> (beta * x, y)` equivalent to scalar multiplication by `s`. pub fn glv_mul(p: ProjectivePoint, k: Secp256K1Scalar) -> ProjectivePoint { let (k1, k2, k1_neg, k2_neg) = decompose_secp256k1_scalar(k); - /*let one = Secp256K1Scalar::ONE; - let m1 = if k1_neg { -one } else { one }; - let m2 = if k2_neg { -one } else { one }; - assert!(k1 * m1 + S * k2 * m2 == k);*/ let p_affine = p.to_affine(); let sp = AffinePoint:: { diff --git a/ecdsa/src/gadgets/biguint.rs b/ecdsa/src/gadgets/biguint.rs index 0c08814c..25dab656 100644 --- a/ecdsa/src/gadgets/biguint.rs +++ b/ecdsa/src/gadgets/biguint.rs @@ -272,12 +272,11 @@ pub fn witness_get_biguint_target, F: PrimeField>( witness: &W, bt: BigUintTarget, ) -> BigUint { - let base = BigUint::from(1usize << 32); bt.limbs .into_iter() .rev() .fold(BigUint::zero(), |acc, limb| { - acc * &base + witness.get_target(limb.0).to_canonical_biguint() + (acc << 32) + witness.get_target(limb.0).to_canonical_biguint() }) } diff --git a/plonky2/src/plonk/circuit_builder.rs b/plonky2/src/plonk/circuit_builder.rs index 721797bd..8e2f2e10 100644 --- a/plonky2/src/plonk/circuit_builder.rs +++ b/plonky2/src/plonk/circuit_builder.rs @@ -70,7 +70,7 @@ pub struct CircuitBuilder, const D: usize> { marked_targets: Vec>, /// Generators used to generate the witness. - pub generators: Vec>>, + generators: Vec>>, constants_to_targets: HashMap, targets_to_constants: HashMap, @@ -150,7 +150,6 @@ impl, const D: usize> CircuitBuilder { /// generate the final witness (a grid of wire values), these virtual targets will go away. pub fn add_virtual_target(&mut self) -> Target { let index = self.virtual_target_index; - self.virtual_target_index += 1; Target::VirtualTarget { index } }