diff --git a/Cargo.toml b/Cargo.toml index cc070d96..00a4b28f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,5 +1,5 @@ [workspace] -members = ["field", "insertion", "plonky2", "starky", "system_zero", "util", "waksman"] +members = ["field", "insertion", "plonky2", "starky", "system_zero", "util", "waksman", "ecdsa"] [profile.release] opt-level = 3 diff --git a/ecdsa/Cargo.toml b/ecdsa/Cargo.toml new file mode 100644 index 00000000..a7e1024b --- /dev/null +++ b/ecdsa/Cargo.toml @@ -0,0 +1,17 @@ +[package] +name = "plonky2_ecdsa" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +plonky2 = { path = "../plonky2" } +plonky2_util = { path = "../util" } +plonky2_field = { path = "../field" } +num = "0.4.0" +itertools = "0.10.0" +rayon = "1.5.1" +serde = { version = "1.0", features = ["derive"] } +anyhow = "1.0.40" +rand = "0.8.4" \ No newline at end of file diff --git a/plonky2/src/curve/curve_adds.rs b/ecdsa/src/curve/curve_adds.rs similarity index 100% rename from plonky2/src/curve/curve_adds.rs rename to ecdsa/src/curve/curve_adds.rs diff --git a/plonky2/src/curve/curve_msm.rs b/ecdsa/src/curve/curve_msm.rs similarity index 100% rename from plonky2/src/curve/curve_msm.rs rename to ecdsa/src/curve/curve_msm.rs diff --git a/plonky2/src/curve/curve_multiplication.rs b/ecdsa/src/curve/curve_multiplication.rs similarity index 99% rename from plonky2/src/curve/curve_multiplication.rs rename to ecdsa/src/curve/curve_multiplication.rs index c6fbbd83..9f2accaf 100644 --- a/plonky2/src/curve/curve_multiplication.rs +++ b/ecdsa/src/curve/curve_multiplication.rs @@ -36,6 +36,7 @@ impl ProjectivePoint { MultiplicationPrecomputation { powers } } + #[must_use] pub fn mul_with_precomputation( &self, scalar: C::ScalarField, diff --git a/plonky2/src/curve/curve_summation.rs b/ecdsa/src/curve/curve_summation.rs similarity index 100% rename from plonky2/src/curve/curve_summation.rs rename to ecdsa/src/curve/curve_summation.rs diff --git a/plonky2/src/curve/curve_types.rs b/ecdsa/src/curve/curve_types.rs similarity index 99% rename from plonky2/src/curve/curve_types.rs rename to ecdsa/src/curve/curve_types.rs index 264120c7..be025e12 100644 --- a/plonky2/src/curve/curve_types.rs +++ b/ecdsa/src/curve/curve_types.rs @@ -78,6 +78,7 @@ impl AffinePoint { affine_points.iter().map(Self::to_projective).collect() } + #[must_use] pub fn double(&self) -> Self { let AffinePoint { x: x1, y: y1, zero } = *self; @@ -187,6 +188,7 @@ impl ProjectivePoint { } // From https://www.hyperelliptic.org/EFD/g1p/data/shortw/projective/doubling/dbl-2007-bl + #[must_use] pub fn double(&self) -> Self { let Self { x, y, z } = *self; if z == C::BaseField::ZERO { @@ -222,6 +224,7 @@ impl ProjectivePoint { .collect() } + #[must_use] pub fn neg(&self) -> Self { Self { x: self.x, diff --git a/plonky2/src/curve/ecdsa.rs b/ecdsa/src/curve/ecdsa.rs similarity index 82% rename from plonky2/src/curve/ecdsa.rs rename to ecdsa/src/curve/ecdsa.rs index 52262830..bb4ebe4a 100644 --- a/plonky2/src/curve/ecdsa.rs +++ b/ecdsa/src/curve/ecdsa.rs @@ -1,8 +1,8 @@ +use plonky2_field::field_types::Field; use serde::{Deserialize, Serialize}; use crate::curve::curve_msm::msm_parallel; use crate::curve::curve_types::{base_to_scalar, AffinePoint, Curve, CurveScalar}; -use crate::field::field_types::Field; #[derive(Copy, Clone, Debug, Deserialize, Eq, Hash, PartialEq, Serialize)] pub struct ECDSASignature { @@ -13,13 +13,15 @@ pub struct ECDSASignature { #[derive(Copy, Clone, Debug, Deserialize, Eq, Hash, PartialEq, Serialize)] pub struct ECDSASecretKey(pub C::ScalarField); +impl ECDSASecretKey { + pub fn to_public(&self) -> ECDSAPublicKey { + ECDSAPublicKey((CurveScalar(self.0) * C::GENERATOR_PROJECTIVE).to_affine()) + } +} + #[derive(Copy, Clone, Debug, Deserialize, Eq, Hash, PartialEq, Serialize)] pub struct ECDSAPublicKey(pub AffinePoint); -pub fn secret_to_public(sk: ECDSASecretKey) -> ECDSAPublicKey { - ECDSAPublicKey((CurveScalar(sk.0) * C::GENERATOR_PROJECTIVE).to_affine()) -} - pub fn sign_message(msg: C::ScalarField, sk: ECDSASecretKey) -> ECDSASignature { let (k, rr) = { let mut k = C::ScalarField::rand(); @@ -61,10 +63,11 @@ pub fn verify_message( #[cfg(test)] mod tests { - use crate::curve::ecdsa::{secret_to_public, sign_message, verify_message, ECDSASecretKey}; + use plonky2_field::field_types::Field; + use plonky2_field::secp256k1_scalar::Secp256K1Scalar; + + use crate::curve::ecdsa::{sign_message, verify_message, ECDSASecretKey}; use crate::curve::secp256k1::Secp256K1; - use crate::field::field_types::Field; - use crate::field::secp256k1_scalar::Secp256K1Scalar; #[test] fn test_ecdsa_native() { @@ -72,7 +75,7 @@ mod tests { let msg = Secp256K1Scalar::rand(); let sk = ECDSASecretKey::(Secp256K1Scalar::rand()); - let pk = secret_to_public(sk); + let pk = sk.to_public(); let sig = sign_message(msg, sk); let result = verify_message(msg, sig, pk); diff --git a/ecdsa/src/curve/glv.rs b/ecdsa/src/curve/glv.rs new file mode 100644 index 00000000..8859d904 --- /dev/null +++ b/ecdsa/src/curve/glv.rs @@ -0,0 +1,140 @@ +use num::rational::Ratio; +use num::BigUint; +use plonky2_field::field_types::{Field, PrimeField}; +use plonky2_field::secp256k1_base::Secp256K1Base; +use plonky2_field::secp256k1_scalar::Secp256K1Scalar; + +use crate::curve::curve_msm::msm_parallel; +use crate::curve::curve_types::{AffinePoint, ProjectivePoint}; +use crate::curve::secp256k1::Secp256K1; + +pub const GLV_BETA: Secp256K1Base = Secp256K1Base([ + 13923278643952681454, + 11308619431505398165, + 7954561588662645993, + 8856726876819556112, +]); + +pub const GLV_S: Secp256K1Scalar = Secp256K1Scalar([ + 16069571880186789234, + 1310022930574435960, + 11900229862571533402, + 6008836872998760672, +]); + +const A1: Secp256K1Scalar = Secp256K1Scalar([16747920425669159701, 3496713202691238861, 0, 0]); + +const MINUS_B1: Secp256K1Scalar = + Secp256K1Scalar([8022177200260244675, 16448129721693014056, 0, 0]); + +const A2: Secp256K1Scalar = Secp256K1Scalar([6323353552219852760, 1498098850674701302, 1, 0]); + +const B2: Secp256K1Scalar = Secp256K1Scalar([16747920425669159701, 3496713202691238861, 0, 0]); + +/// Algorithm 15.41 in Handbook of Elliptic and Hyperelliptic Curve Cryptography. +/// Decompose a scalar `k` into two small scalars `k1, k2` with `|k1|, |k2| < √p` that satisfy +/// `k1 + s * k2 = k`. +/// Returns `(|k1|, |k2|, k1 < 0, k2 < 0)`. +pub fn decompose_secp256k1_scalar( + k: Secp256K1Scalar, +) -> (Secp256K1Scalar, Secp256K1Scalar, bool, bool) { + let p = Secp256K1Scalar::order(); + let c1_biguint = Ratio::new( + B2.to_canonical_biguint() * k.to_canonical_biguint(), + p.clone(), + ) + .round() + .to_integer(); + let c1 = Secp256K1Scalar::from_biguint(c1_biguint); + let c2_biguint = Ratio::new( + MINUS_B1.to_canonical_biguint() * k.to_canonical_biguint(), + p.clone(), + ) + .round() + .to_integer(); + let c2 = Secp256K1Scalar::from_biguint(c2_biguint); + + let k1_raw = k - c1 * A1 - c2 * A2; + let k2_raw = c1 * MINUS_B1 - c2 * B2; + debug_assert!(k1_raw + GLV_S * k2_raw == k); + + let two = BigUint::from_slice(&[2]); + let k1_neg = k1_raw.to_canonical_biguint() > p.clone() / two.clone(); + let k1 = if k1_neg { + Secp256K1Scalar::from_biguint(p.clone() - k1_raw.to_canonical_biguint()) + } else { + k1_raw + }; + let k2_neg = k2_raw.to_canonical_biguint() > p.clone() / two; + let k2 = if k2_neg { + Secp256K1Scalar::from_biguint(p - k2_raw.to_canonical_biguint()) + } else { + k2_raw + }; + + (k1, k2, k1_neg, k2_neg) +} + +/// See Section 15.2.1 in Handbook of Elliptic and Hyperelliptic Curve Cryptography. +/// GLV scalar multiplication `k * P = k1 * P + k2 * psi(P)`, where `k = k1 + s * k2` is the +/// decomposition computed in `decompose_secp256k1_scalar(k)` and `psi` is the Secp256k1 +/// endomorphism `psi: (x, y) |-> (beta * x, y)` equivalent to scalar multiplication by `s`. +pub fn glv_mul(p: ProjectivePoint, k: Secp256K1Scalar) -> ProjectivePoint { + let (k1, k2, k1_neg, k2_neg) = decompose_secp256k1_scalar(k); + + let p_affine = p.to_affine(); + let sp = AffinePoint:: { + x: p_affine.x * GLV_BETA, + y: p_affine.y, + zero: p_affine.zero, + }; + + let first = if k1_neg { p.neg() } else { p }; + let second = if k2_neg { + sp.to_projective().neg() + } else { + sp.to_projective() + }; + + msm_parallel(&[k1, k2], &[first, second], 5) +} + +#[cfg(test)] +mod tests { + use anyhow::Result; + use plonky2_field::field_types::Field; + use plonky2_field::secp256k1_scalar::Secp256K1Scalar; + + use crate::curve::curve_types::{Curve, CurveScalar}; + use crate::curve::glv::{decompose_secp256k1_scalar, glv_mul, GLV_S}; + use crate::curve::secp256k1::Secp256K1; + + #[test] + fn test_glv_decompose() -> Result<()> { + let k = Secp256K1Scalar::rand(); + let (k1, k2, k1_neg, k2_neg) = decompose_secp256k1_scalar(k); + let one = Secp256K1Scalar::ONE; + let m1 = if k1_neg { -one } else { one }; + let m2 = if k2_neg { -one } else { one }; + + assert!(k1 * m1 + GLV_S * k2 * m2 == k); + + Ok(()) + } + + #[test] + fn test_glv_mul() -> Result<()> { + for _ in 0..20 { + let k = Secp256K1Scalar::rand(); + + let p = CurveScalar(Secp256K1Scalar::rand()) * Secp256K1::GENERATOR_PROJECTIVE; + + let kp = CurveScalar(k) * p; + let glv = glv_mul(p, k); + + assert!(kp == glv); + } + + Ok(()) + } +} diff --git a/plonky2/src/curve/mod.rs b/ecdsa/src/curve/mod.rs similarity index 91% rename from plonky2/src/curve/mod.rs rename to ecdsa/src/curve/mod.rs index 8dd6f0d6..1984b0c6 100644 --- a/plonky2/src/curve/mod.rs +++ b/ecdsa/src/curve/mod.rs @@ -4,4 +4,5 @@ pub mod curve_multiplication; pub mod curve_summation; pub mod curve_types; pub mod ecdsa; +pub mod glv; pub mod secp256k1; diff --git a/plonky2/src/curve/secp256k1.rs b/ecdsa/src/curve/secp256k1.rs similarity index 100% rename from plonky2/src/curve/secp256k1.rs rename to ecdsa/src/curve/secp256k1.rs diff --git a/plonky2/src/gadgets/biguint.rs b/ecdsa/src/gadgets/biguint.rs similarity index 69% rename from plonky2/src/gadgets/biguint.rs rename to ecdsa/src/gadgets/biguint.rs index c9ad7280..25dab656 100644 --- a/plonky2/src/gadgets/biguint.rs +++ b/ecdsa/src/gadgets/biguint.rs @@ -1,14 +1,14 @@ use std::marker::PhantomData; use num::{BigUint, Integer, Zero}; +use plonky2::gadgets::arithmetic_u32::U32Target; +use plonky2::hash::hash_types::RichField; +use plonky2::iop::generator::{GeneratedValues, SimpleGenerator}; +use plonky2::iop::target::{BoolTarget, Target}; +use plonky2::iop::witness::{PartitionWitness, Witness}; +use plonky2::plonk::circuit_builder::CircuitBuilder; use plonky2_field::extension_field::Extendable; - -use crate::gadgets::arithmetic_u32::U32Target; -use crate::hash::hash_types::RichField; -use crate::iop::generator::{GeneratedValues, SimpleGenerator}; -use crate::iop::target::{BoolTarget, Target}; -use crate::iop::witness::{PartitionWitness, Witness}; -use crate::plonk::circuit_builder::CircuitBuilder; +use plonky2_field::field_types::PrimeField; #[derive(Clone, Debug)] pub struct BigUintTarget { @@ -25,19 +25,67 @@ impl BigUintTarget { } } -impl, const D: usize> CircuitBuilder { - pub fn constant_biguint(&mut self, value: &BigUint) -> BigUintTarget { +pub trait CircuitBuilderBiguint, const D: usize> { + fn constant_biguint(&mut self, value: &BigUint) -> BigUintTarget; + + fn zero_biguint(&mut self) -> BigUintTarget; + + fn connect_biguint(&mut self, lhs: &BigUintTarget, rhs: &BigUintTarget); + + fn pad_biguints( + &mut self, + a: &BigUintTarget, + b: &BigUintTarget, + ) -> (BigUintTarget, BigUintTarget); + + fn cmp_biguint(&mut self, a: &BigUintTarget, b: &BigUintTarget) -> BoolTarget; + + fn add_virtual_biguint_target(&mut self, num_limbs: usize) -> BigUintTarget; + + /// Add two `BigUintTarget`s. + fn add_biguint(&mut self, a: &BigUintTarget, b: &BigUintTarget) -> BigUintTarget; + + /// Subtract two `BigUintTarget`s. We assume that the first is larger than the second. + fn sub_biguint(&mut self, a: &BigUintTarget, b: &BigUintTarget) -> BigUintTarget; + + fn mul_biguint(&mut self, a: &BigUintTarget, b: &BigUintTarget) -> BigUintTarget; + + fn mul_biguint_by_bool(&mut self, a: &BigUintTarget, b: BoolTarget) -> BigUintTarget; + + /// Returns x * y + z. This is no more efficient than mul-then-add; it's purely for convenience (only need to call one CircuitBuilder function). + fn mul_add_biguint( + &mut self, + x: &BigUintTarget, + y: &BigUintTarget, + z: &BigUintTarget, + ) -> BigUintTarget; + + fn div_rem_biguint( + &mut self, + a: &BigUintTarget, + b: &BigUintTarget, + ) -> (BigUintTarget, BigUintTarget); + + fn div_biguint(&mut self, a: &BigUintTarget, b: &BigUintTarget) -> BigUintTarget; + + fn rem_biguint(&mut self, a: &BigUintTarget, b: &BigUintTarget) -> BigUintTarget; +} + +impl, const D: usize> CircuitBuilderBiguint + for CircuitBuilder +{ + fn constant_biguint(&mut self, value: &BigUint) -> BigUintTarget { let limb_values = value.to_u32_digits(); let limbs = limb_values.iter().map(|&l| self.constant_u32(l)).collect(); BigUintTarget { limbs } } - pub fn zero_biguint(&mut self) -> BigUintTarget { + fn zero_biguint(&mut self) -> BigUintTarget { self.constant_biguint(&BigUint::zero()) } - pub fn connect_biguint(&mut self, lhs: &BigUintTarget, rhs: &BigUintTarget) { + fn connect_biguint(&mut self, lhs: &BigUintTarget, rhs: &BigUintTarget) { let min_limbs = lhs.num_limbs().min(rhs.num_limbs()); for i in 0..min_limbs { self.connect_u32(lhs.get_limb(i), rhs.get_limb(i)); @@ -51,7 +99,7 @@ impl, const D: usize> CircuitBuilder { } } - pub fn pad_biguints( + fn pad_biguints( &mut self, a: &BigUintTarget, b: &BigUintTarget, @@ -73,20 +121,19 @@ impl, const D: usize> CircuitBuilder { } } - pub fn cmp_biguint(&mut self, a: &BigUintTarget, b: &BigUintTarget) -> BoolTarget { + fn cmp_biguint(&mut self, a: &BigUintTarget, b: &BigUintTarget) -> BoolTarget { let (a, b) = self.pad_biguints(a, b); self.list_le_u32(a.limbs, b.limbs) } - pub fn add_virtual_biguint_target(&mut self, num_limbs: usize) -> BigUintTarget { + fn add_virtual_biguint_target(&mut self, num_limbs: usize) -> BigUintTarget { let limbs = self.add_virtual_u32_targets(num_limbs); BigUintTarget { limbs } } - // Add two `BigUintTarget`s. - pub fn add_biguint(&mut self, a: &BigUintTarget, b: &BigUintTarget) -> BigUintTarget { + fn add_biguint(&mut self, a: &BigUintTarget, b: &BigUintTarget) -> BigUintTarget { let num_limbs = a.num_limbs().max(b.num_limbs()); let mut combined_limbs = vec![]; @@ -110,8 +157,7 @@ impl, const D: usize> CircuitBuilder { } } - // Subtract two `BigUintTarget`s. We assume that the first is larger than the second. - pub fn sub_biguint(&mut self, a: &BigUintTarget, b: &BigUintTarget) -> BigUintTarget { + fn sub_biguint(&mut self, a: &BigUintTarget, b: &BigUintTarget) -> BigUintTarget { let (a, b) = self.pad_biguints(a, b); let num_limbs = a.limbs.len(); @@ -130,7 +176,7 @@ impl, const D: usize> CircuitBuilder { } } - pub fn mul_biguint(&mut self, a: &BigUintTarget, b: &BigUintTarget) -> BigUintTarget { + fn mul_biguint(&mut self, a: &BigUintTarget, b: &BigUintTarget) -> BigUintTarget { let total_limbs = a.limbs.len() + b.limbs.len(); let mut to_add = vec![vec![]; total_limbs]; @@ -156,7 +202,7 @@ impl, const D: usize> CircuitBuilder { } } - pub fn mul_biguint_by_bool(&mut self, a: &BigUintTarget, b: BoolTarget) -> BigUintTarget { + fn mul_biguint_by_bool(&mut self, a: &BigUintTarget, b: BoolTarget) -> BigUintTarget { let t = b.target; BigUintTarget { @@ -168,8 +214,7 @@ impl, const D: usize> CircuitBuilder { } } - // Returns x * y + z. This is no more efficient than mul-then-add; it's purely for convenience (only need to call one CircuitBuilder function). - pub fn mul_add_biguint( + fn mul_add_biguint( &mut self, x: &BigUintTarget, y: &BigUintTarget, @@ -179,7 +224,7 @@ impl, const D: usize> CircuitBuilder { self.add_biguint(&prod, z) } - pub fn div_rem_biguint( + fn div_rem_biguint( &mut self, a: &BigUintTarget, b: &BigUintTarget, @@ -212,17 +257,55 @@ impl, const D: usize> CircuitBuilder { (div, rem) } - pub fn div_biguint(&mut self, a: &BigUintTarget, b: &BigUintTarget) -> BigUintTarget { + fn div_biguint(&mut self, a: &BigUintTarget, b: &BigUintTarget) -> BigUintTarget { let (div, _rem) = self.div_rem_biguint(a, b); div } - pub fn rem_biguint(&mut self, a: &BigUintTarget, b: &BigUintTarget) -> BigUintTarget { + fn rem_biguint(&mut self, a: &BigUintTarget, b: &BigUintTarget) -> BigUintTarget { let (_div, rem) = self.div_rem_biguint(a, b); rem } } +pub fn witness_get_biguint_target, F: PrimeField>( + witness: &W, + bt: BigUintTarget, +) -> BigUint { + bt.limbs + .into_iter() + .rev() + .fold(BigUint::zero(), |acc, limb| { + (acc << 32) + witness.get_target(limb.0).to_canonical_biguint() + }) +} + +pub fn witness_set_biguint_target, F: PrimeField>( + witness: &mut W, + target: &BigUintTarget, + value: &BigUint, +) { + let mut limbs = value.to_u32_digits(); + assert!(target.num_limbs() >= limbs.len()); + limbs.resize(target.num_limbs(), 0); + for i in 0..target.num_limbs() { + witness.set_u32_target(target.limbs[i], limbs[i]); + } +} + +pub fn buffer_set_biguint_target( + buffer: &mut GeneratedValues, + target: &BigUintTarget, + value: &BigUint, +) { + let mut limbs = value.to_u32_digits(); + assert!(target.num_limbs() >= limbs.len()); + limbs.resize(target.num_limbs(), 0); + for i in 0..target.num_limbs() { + buffer.set_u32_target(target.get_limb(i), limbs[i]); + } +} + #[derive(Debug)] struct BigUintDivRemGenerator, const D: usize> { a: BigUintTarget, @@ -245,12 +328,12 @@ impl, const D: usize> SimpleGenerator } fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { - let a = witness.get_biguint_target(self.a.clone()); - let b = witness.get_biguint_target(self.b.clone()); + let a = witness_get_biguint_target(witness, self.a.clone()); + let b = witness_get_biguint_target(witness, self.b.clone()); let (div, rem) = a.div_rem(&b); - out_buffer.set_biguint_target(self.div.clone(), div); - out_buffer.set_biguint_target(self.rem.clone(), rem); + buffer_set_biguint_target(out_buffer, &self.div, &div); + buffer_set_biguint_target(out_buffer, &self.rem, &rem); } } @@ -258,14 +341,14 @@ impl, const D: usize> SimpleGenerator mod tests { use anyhow::Result; use num::{BigUint, FromPrimitive, Integer}; + use plonky2::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; + use plonky2::{ + iop::witness::PartialWitness, + plonk::{circuit_builder::CircuitBuilder, circuit_data::CircuitConfig}, + }; use rand::Rng; - use crate::iop::witness::Witness; - use crate::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; - use crate::{ - iop::witness::PartialWitness, - plonk::{circuit_builder::CircuitBuilder, circuit_data::CircuitConfig, verifier::verify}, - }; + use crate::gadgets::biguint::{witness_set_biguint_target, CircuitBuilderBiguint}; #[test] fn test_biguint_add() -> Result<()> { @@ -288,13 +371,13 @@ mod tests { let expected_z = builder.add_virtual_biguint_target(expected_z_value.to_u32_digits().len()); builder.connect_biguint(&z, &expected_z); - pw.set_biguint_target(&x, &x_value); - pw.set_biguint_target(&y, &y_value); - pw.set_biguint_target(&expected_z, &expected_z_value); + witness_set_biguint_target(&mut pw, &x, &x_value); + witness_set_biguint_target(&mut pw, &y, &y_value); + witness_set_biguint_target(&mut pw, &expected_z, &expected_z_value); let data = builder.build::(); let proof = data.prove(pw).unwrap(); - verify(proof, &data.verifier_only, &data.common) + data.verify(proof) } #[test] @@ -324,7 +407,7 @@ mod tests { let data = builder.build::(); let proof = data.prove(pw).unwrap(); - verify(proof, &data.verifier_only, &data.common) + data.verify(proof) } #[test] @@ -348,13 +431,13 @@ mod tests { let expected_z = builder.add_virtual_biguint_target(expected_z_value.to_u32_digits().len()); builder.connect_biguint(&z, &expected_z); - pw.set_biguint_target(&x, &x_value); - pw.set_biguint_target(&y, &y_value); - pw.set_biguint_target(&expected_z, &expected_z_value); + witness_set_biguint_target(&mut pw, &x, &x_value); + witness_set_biguint_target(&mut pw, &y, &y_value); + witness_set_biguint_target(&mut pw, &expected_z, &expected_z_value); let data = builder.build::(); let proof = data.prove(pw).unwrap(); - verify(proof, &data.verifier_only, &data.common) + data.verify(proof) } #[test] @@ -380,7 +463,7 @@ mod tests { let data = builder.build::(); let proof = data.prove(pw).unwrap(); - verify(proof, &data.verifier_only, &data.common) + data.verify(proof) } #[test] @@ -413,6 +496,6 @@ mod tests { let data = builder.build::(); let proof = data.prove(pw).unwrap(); - verify(proof, &data.verifier_only, &data.common) + data.verify(proof) } } diff --git a/plonky2/src/gadgets/curve.rs b/ecdsa/src/gadgets/curve.rs similarity index 68% rename from plonky2/src/gadgets/curve.rs rename to ecdsa/src/gadgets/curve.rs index 8c182345..e511973b 100644 --- a/plonky2/src/gadgets/curve.rs +++ b/ecdsa/src/gadgets/curve.rs @@ -1,10 +1,11 @@ +use plonky2::hash::hash_types::RichField; +use plonky2::iop::target::BoolTarget; +use plonky2::plonk::circuit_builder::CircuitBuilder; use plonky2_field::extension_field::Extendable; use plonky2_field::field_types::Field; use crate::curve::curve_types::{AffinePoint, Curve, CurveScalar}; -use crate::gadgets::nonnative::NonNativeTarget; -use crate::hash::hash_types::RichField; -use crate::plonk::circuit_builder::CircuitBuilder; +use crate::gadgets::nonnative::{CircuitBuilderNonNative, NonNativeTarget}; /// A Target representing an affine point on the curve `C`. We use incomplete arithmetic for efficiency, /// so we assume these points are not zero. @@ -20,11 +21,60 @@ impl AffinePointTarget { } } -impl, const D: usize> CircuitBuilder { - pub fn constant_affine_point( +pub trait CircuitBuilderCurve, const D: usize> { + fn constant_affine_point(&mut self, point: AffinePoint) -> AffinePointTarget; + + fn connect_affine_point( &mut self, - point: AffinePoint, - ) -> AffinePointTarget { + lhs: &AffinePointTarget, + rhs: &AffinePointTarget, + ); + + fn add_virtual_affine_point_target(&mut self) -> AffinePointTarget; + + fn curve_assert_valid(&mut self, p: &AffinePointTarget); + + fn curve_neg(&mut self, p: &AffinePointTarget) -> AffinePointTarget; + + fn curve_conditional_neg( + &mut self, + p: &AffinePointTarget, + b: BoolTarget, + ) -> AffinePointTarget; + + fn curve_double(&mut self, p: &AffinePointTarget) -> AffinePointTarget; + + fn curve_repeated_double( + &mut self, + p: &AffinePointTarget, + n: usize, + ) -> AffinePointTarget; + + /// Add two points, which are assumed to be non-equal. + fn curve_add( + &mut self, + p1: &AffinePointTarget, + p2: &AffinePointTarget, + ) -> AffinePointTarget; + + fn curve_conditional_add( + &mut self, + p1: &AffinePointTarget, + p2: &AffinePointTarget, + b: BoolTarget, + ) -> AffinePointTarget; + + fn curve_scalar_mul( + &mut self, + p: &AffinePointTarget, + n: &NonNativeTarget, + ) -> AffinePointTarget; +} + +impl, const D: usize> CircuitBuilderCurve + for CircuitBuilder +{ + fn constant_affine_point(&mut self, point: AffinePoint) -> AffinePointTarget { debug_assert!(!point.zero); AffinePointTarget { x: self.constant_nonnative(point.x), @@ -32,7 +82,7 @@ impl, const D: usize> CircuitBuilder { } } - pub fn connect_affine_point( + fn connect_affine_point( &mut self, lhs: &AffinePointTarget, rhs: &AffinePointTarget, @@ -41,14 +91,14 @@ impl, const D: usize> CircuitBuilder { self.connect_nonnative(&lhs.y, &rhs.y); } - pub fn add_virtual_affine_point_target(&mut self) -> AffinePointTarget { + fn add_virtual_affine_point_target(&mut self) -> AffinePointTarget { let x = self.add_virtual_nonnative_target(); let y = self.add_virtual_nonnative_target(); AffinePointTarget { x, y } } - pub fn curve_assert_valid(&mut self, p: &AffinePointTarget) { + fn curve_assert_valid(&mut self, p: &AffinePointTarget) { let a = self.constant_nonnative(C::A); let b = self.constant_nonnative(C::B); @@ -62,7 +112,7 @@ impl, const D: usize> CircuitBuilder { self.connect_nonnative(&y_squared, &rhs); } - pub fn curve_neg(&mut self, p: &AffinePointTarget) -> AffinePointTarget { + fn curve_neg(&mut self, p: &AffinePointTarget) -> AffinePointTarget { let neg_y = self.neg_nonnative(&p.y); AffinePointTarget { x: p.x.clone(), @@ -70,7 +120,18 @@ impl, const D: usize> CircuitBuilder { } } - pub fn curve_double(&mut self, p: &AffinePointTarget) -> AffinePointTarget { + fn curve_conditional_neg( + &mut self, + p: &AffinePointTarget, + b: BoolTarget, + ) -> AffinePointTarget { + AffinePointTarget { + x: p.x.clone(), + y: self.nonnative_conditional_neg(&p.y, b), + } + } + + fn curve_double(&mut self, p: &AffinePointTarget) -> AffinePointTarget { let AffinePointTarget { x, y } = p; let double_y = self.add_nonnative(y, y); let inv_double_y = self.inv_nonnative(&double_y); @@ -94,8 +155,21 @@ impl, const D: usize> CircuitBuilder { AffinePointTarget { x: x3, y: y3 } } - // Add two points, which are assumed to be non-equal. - pub fn curve_add( + fn curve_repeated_double( + &mut self, + p: &AffinePointTarget, + n: usize, + ) -> AffinePointTarget { + let mut result = p.clone(); + + for _ in 0..n { + result = self.curve_double(&result); + } + + result + } + + fn curve_add( &mut self, p1: &AffinePointTarget, p2: &AffinePointTarget, @@ -117,7 +191,26 @@ impl, const D: usize> CircuitBuilder { AffinePointTarget { x: x3, y: y3 } } - pub fn curve_scalar_mul( + fn curve_conditional_add( + &mut self, + p1: &AffinePointTarget, + p2: &AffinePointTarget, + b: BoolTarget, + ) -> AffinePointTarget { + let not_b = self.not(b); + let sum = self.curve_add(p1, p2); + let x_if_true = self.mul_nonnative_by_bool(&sum.x, b); + let y_if_true = self.mul_nonnative_by_bool(&sum.y, b); + let x_if_false = self.mul_nonnative_by_bool(&p1.x, not_b); + let y_if_false = self.mul_nonnative_by_bool(&p1.y, not_b); + + let x = self.add_nonnative(&x_if_true, &x_if_false); + let y = self.add_nonnative(&y_if_true, &y_if_false); + + AffinePointTarget { x, y } + } + + fn curve_scalar_mul( &mut self, p: &AffinePointTarget, n: &NonNativeTarget, @@ -164,17 +257,18 @@ mod tests { use std::ops::Neg; use anyhow::Result; + use plonky2::iop::witness::PartialWitness; + use plonky2::plonk::circuit_builder::CircuitBuilder; + use plonky2::plonk::circuit_data::CircuitConfig; + use plonky2::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; use plonky2_field::field_types::Field; use plonky2_field::secp256k1_base::Secp256K1Base; use plonky2_field::secp256k1_scalar::Secp256K1Scalar; use crate::curve::curve_types::{AffinePoint, Curve, CurveScalar}; use crate::curve::secp256k1::Secp256K1; - use crate::iop::witness::PartialWitness; - use crate::plonk::circuit_builder::CircuitBuilder; - use crate::plonk::circuit_data::CircuitConfig; - use crate::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; - use crate::plonk::verifier::verify; + use crate::gadgets::curve::CircuitBuilderCurve; + use crate::gadgets::nonnative::CircuitBuilderNonNative; #[test] fn test_curve_point_is_valid() -> Result<()> { @@ -197,7 +291,7 @@ mod tests { let data = builder.build::(); let proof = data.prove(pw).unwrap(); - verify(proof, &data.verifier_only, &data.common) + data.verify(proof) } #[test] @@ -225,7 +319,7 @@ mod tests { let data = builder.build::(); let proof = data.prove(pw).unwrap(); - verify(proof, &data.verifier_only, &data.common).unwrap(); + data.verify(proof).unwrap() } #[test] @@ -262,7 +356,7 @@ mod tests { let data = builder.build::(); let proof = data.prove(pw).unwrap(); - verify(proof, &data.verifier_only, &data.common) + data.verify(proof) } #[test] @@ -292,7 +386,39 @@ mod tests { let data = builder.build::(); let proof = data.prove(pw).unwrap(); - verify(proof, &data.verifier_only, &data.common) + data.verify(proof) + } + + #[test] + fn test_curve_conditional_add() -> Result<()> { + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + + let config = CircuitConfig::standard_ecc_config(); + + let pw = PartialWitness::new(); + let mut builder = CircuitBuilder::::new(config); + + let g = Secp256K1::GENERATOR_AFFINE; + let double_g = g.double(); + let g_plus_2g = (g + double_g).to_affine(); + let g_plus_2g_expected = builder.constant_affine_point(g_plus_2g); + + let g_expected = builder.constant_affine_point(g); + let double_g_target = builder.curve_double(&g_expected); + let t = builder._true(); + let f = builder._false(); + let g_plus_2g_actual = builder.curve_conditional_add(&g_expected, &double_g_target, t); + let g_actual = builder.curve_conditional_add(&g_expected, &double_g_target, f); + + builder.connect_affine_point(&g_plus_2g_expected, &g_plus_2g_actual); + builder.connect_affine_point(&g_expected, &g_actual); + + let data = builder.build::(); + let proof = data.prove(pw).unwrap(); + + data.verify(proof) } #[test] @@ -307,7 +433,7 @@ mod tests { let pw = PartialWitness::new(); let mut builder = CircuitBuilder::::new(config); - let g = Secp256K1::GENERATOR_AFFINE; + let g = Secp256K1::GENERATOR_PROJECTIVE.to_affine(); let five = Secp256K1Scalar::from_canonical_usize(5); let neg_five = five.neg(); let neg_five_scalar = CurveScalar::(neg_five); @@ -325,10 +451,11 @@ mod tests { let data = builder.build::(); let proof = data.prove(pw).unwrap(); - verify(proof, &data.verifier_only, &data.common) + data.verify(proof) } #[test] + #[ignore] fn test_curve_random() -> Result<()> { const D: usize = 2; type C = PoseidonGoldilocksConfig; @@ -351,6 +478,6 @@ mod tests { let data = builder.build::(); let proof = data.prove(pw).unwrap(); - verify(proof, &data.verifier_only, &data.common) + data.verify(proof) } } diff --git a/ecdsa/src/gadgets/curve_fixed_base.rs b/ecdsa/src/gadgets/curve_fixed_base.rs new file mode 100644 index 00000000..e64ec134 --- /dev/null +++ b/ecdsa/src/gadgets/curve_fixed_base.rs @@ -0,0 +1,113 @@ +use num::BigUint; +use plonky2::hash::hash_types::RichField; +use plonky2::hash::keccak::KeccakHash; +use plonky2::plonk::circuit_builder::CircuitBuilder; +use plonky2::plonk::config::{GenericHashOut, Hasher}; +use plonky2_field::extension_field::Extendable; +use plonky2_field::field_types::Field; + +use crate::curve::curve_types::{AffinePoint, Curve, CurveScalar}; +use crate::gadgets::curve::{AffinePointTarget, CircuitBuilderCurve}; +use crate::gadgets::curve_windowed_mul::CircuitBuilderWindowedMul; +use crate::gadgets::nonnative::NonNativeTarget; +use crate::gadgets::split_nonnative::CircuitBuilderSplit; + +/// Compute windowed fixed-base scalar multiplication, using a 4-bit window. +pub fn fixed_base_curve_mul_circuit, const D: usize>( + builder: &mut CircuitBuilder, + base: AffinePoint, + scalar: &NonNativeTarget, +) -> AffinePointTarget { + // Holds `(16^i) * base` for `i=0..scalar.value.limbs.len() * 8`. + let scaled_base = (0..scalar.value.limbs.len() * 8).scan(base, |acc, _| { + let tmp = *acc; + for _ in 0..4 { + *acc = acc.double(); + } + Some(tmp) + }); + + let limbs = builder.split_nonnative_to_4_bit_limbs(scalar); + + let hash_0 = KeccakHash::<32>::hash_no_pad(&[F::ZERO]); + let hash_0_scalar = C::ScalarField::from_biguint(BigUint::from_bytes_le( + &GenericHashOut::::to_bytes(&hash_0), + )); + let rando = (CurveScalar(hash_0_scalar) * C::GENERATOR_PROJECTIVE).to_affine(); + + let zero = builder.zero(); + let mut result = builder.constant_affine_point(rando); + // `s * P = sum s_i * P_i` with `P_i = (16^i) * P` and `s = sum s_i * (16^i)`. + for (limb, point) in limbs.into_iter().zip(scaled_base) { + // `muls_point[t] = t * P_i` for `t=0..16`. + let muls_point = (0..16) + .scan(AffinePoint::ZERO, |acc, _| { + let tmp = *acc; + *acc = (point + *acc).to_affine(); + Some(tmp) + }) + .map(|p| builder.constant_affine_point(p)) + .collect::>(); + let is_zero = builder.is_equal(limb, zero); + let should_add = builder.not(is_zero); + // `r = s_i * P_i` + let r = builder.random_access_curve_points(limb, muls_point); + result = builder.curve_conditional_add(&result, &r, should_add); + } + + let to_add = builder.constant_affine_point(-rando); + builder.curve_add(&result, &to_add) +} + +#[cfg(test)] +mod tests { + use anyhow::Result; + use plonky2::iop::witness::PartialWitness; + use plonky2::plonk::circuit_builder::CircuitBuilder; + use plonky2::plonk::circuit_data::CircuitConfig; + use plonky2::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; + use plonky2_field::field_types::Field; + use plonky2_field::field_types::PrimeField; + use plonky2_field::secp256k1_scalar::Secp256K1Scalar; + + use crate::curve::curve_types::{Curve, CurveScalar}; + use crate::curve::secp256k1::Secp256K1; + use crate::gadgets::biguint::witness_set_biguint_target; + use crate::gadgets::curve::CircuitBuilderCurve; + use crate::gadgets::curve_fixed_base::fixed_base_curve_mul_circuit; + use crate::gadgets::nonnative::CircuitBuilderNonNative; + + #[test] + #[ignore] + fn test_fixed_base() -> Result<()> { + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + + let config = CircuitConfig::standard_ecc_config(); + + let mut pw = PartialWitness::new(); + let mut builder = CircuitBuilder::::new(config); + + let g = Secp256K1::GENERATOR_AFFINE; + let n = Secp256K1Scalar::rand(); + + let res = (CurveScalar(n) * g.to_projective()).to_affine(); + let res_expected = builder.constant_affine_point(res); + builder.curve_assert_valid(&res_expected); + + let n_target = builder.add_virtual_nonnative_target::(); + witness_set_biguint_target(&mut pw, &n_target.value, &n.to_canonical_biguint()); + + let res_target = fixed_base_curve_mul_circuit(&mut builder, g, &n_target); + builder.curve_assert_valid(&res_target); + + builder.connect_affine_point(&res_target, &res_expected); + + dbg!(builder.num_gates()); + let data = builder.build::(); + let proof = data.prove(pw).unwrap(); + + data.verify(proof) + } +} diff --git a/ecdsa/src/gadgets/curve_msm.rs b/ecdsa/src/gadgets/curve_msm.rs new file mode 100644 index 00000000..c57cecd2 --- /dev/null +++ b/ecdsa/src/gadgets/curve_msm.rs @@ -0,0 +1,136 @@ +use num::BigUint; +use plonky2::hash::hash_types::RichField; +use plonky2::hash::keccak::KeccakHash; +use plonky2::plonk::circuit_builder::CircuitBuilder; +use plonky2::plonk::config::{GenericHashOut, Hasher}; +use plonky2_field::extension_field::Extendable; +use plonky2_field::field_types::Field; + +use crate::curve::curve_types::{Curve, CurveScalar}; +use crate::gadgets::curve::{AffinePointTarget, CircuitBuilderCurve}; +use crate::gadgets::curve_windowed_mul::CircuitBuilderWindowedMul; +use crate::gadgets::nonnative::NonNativeTarget; +use crate::gadgets::split_nonnative::CircuitBuilderSplit; + +/// Computes `n*p + m*q` using windowed MSM, with a 2-bit window. +/// See Algorithm 9.23 in Handbook of Elliptic and Hyperelliptic Curve Cryptography for a +/// description. +/// Note: Doesn't work if `p == q`. +pub fn curve_msm_circuit, const D: usize>( + builder: &mut CircuitBuilder, + p: &AffinePointTarget, + q: &AffinePointTarget, + n: &NonNativeTarget, + m: &NonNativeTarget, +) -> AffinePointTarget { + let limbs_n = builder.split_nonnative_to_2_bit_limbs(n); + let limbs_m = builder.split_nonnative_to_2_bit_limbs(m); + assert_eq!(limbs_n.len(), limbs_m.len()); + let num_limbs = limbs_n.len(); + + let hash_0 = KeccakHash::<32>::hash_no_pad(&[F::ZERO]); + let hash_0_scalar = C::ScalarField::from_biguint(BigUint::from_bytes_le( + &GenericHashOut::::to_bytes(&hash_0), + )); + let rando = (CurveScalar(hash_0_scalar) * C::GENERATOR_PROJECTIVE).to_affine(); + let rando_t = builder.constant_affine_point(rando); + let neg_rando = builder.constant_affine_point(-rando); + + // Precomputes `precomputation[i + 4*j] = i*p + j*q` for `i,j=0..4`. + let mut precomputation = vec![p.clone(); 16]; + let mut cur_p = rando_t.clone(); + let mut cur_q = rando_t.clone(); + for i in 0..4 { + precomputation[i] = cur_p.clone(); + precomputation[4 * i] = cur_q.clone(); + cur_p = builder.curve_add(&cur_p, p); + cur_q = builder.curve_add(&cur_q, q); + } + for i in 1..4 { + precomputation[i] = builder.curve_add(&precomputation[i], &neg_rando); + precomputation[4 * i] = builder.curve_add(&precomputation[4 * i], &neg_rando); + } + for i in 1..4 { + for j in 1..4 { + precomputation[i + 4 * j] = + builder.curve_add(&precomputation[i], &precomputation[4 * j]); + } + } + + let four = builder.constant(F::from_canonical_usize(4)); + + let zero = builder.zero(); + let mut result = rando_t; + for (limb_n, limb_m) in limbs_n.into_iter().zip(limbs_m).rev() { + result = builder.curve_repeated_double(&result, 2); + let index = builder.mul_add(four, limb_m, limb_n); + let r = builder.random_access_curve_points(index, precomputation.clone()); + let is_zero = builder.is_equal(index, zero); + let should_add = builder.not(is_zero); + result = builder.curve_conditional_add(&result, &r, should_add); + } + let starting_point_multiplied = (0..2 * num_limbs).fold(rando, |acc, _| acc.double()); + let to_add = builder.constant_affine_point(-starting_point_multiplied); + result = builder.curve_add(&result, &to_add); + + result +} + +#[cfg(test)] +mod tests { + use anyhow::Result; + use plonky2::iop::witness::PartialWitness; + use plonky2::plonk::circuit_builder::CircuitBuilder; + use plonky2::plonk::circuit_data::CircuitConfig; + use plonky2::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; + use plonky2_field::field_types::Field; + use plonky2_field::secp256k1_scalar::Secp256K1Scalar; + + use crate::curve::curve_types::{Curve, CurveScalar}; + use crate::curve::secp256k1::Secp256K1; + use crate::gadgets::curve::CircuitBuilderCurve; + use crate::gadgets::curve_msm::curve_msm_circuit; + use crate::gadgets::nonnative::CircuitBuilderNonNative; + + #[test] + #[ignore] + fn test_curve_msm() -> Result<()> { + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + + let config = CircuitConfig::standard_ecc_config(); + + let pw = PartialWitness::new(); + let mut builder = CircuitBuilder::::new(config); + + let p = + (CurveScalar(Secp256K1Scalar::rand()) * Secp256K1::GENERATOR_PROJECTIVE).to_affine(); + let q = + (CurveScalar(Secp256K1Scalar::rand()) * Secp256K1::GENERATOR_PROJECTIVE).to_affine(); + let n = Secp256K1Scalar::rand(); + let m = Secp256K1Scalar::rand(); + + let res = + (CurveScalar(n) * p.to_projective() + CurveScalar(m) * q.to_projective()).to_affine(); + let res_expected = builder.constant_affine_point(res); + builder.curve_assert_valid(&res_expected); + + let p_target = builder.constant_affine_point(p); + let q_target = builder.constant_affine_point(q); + let n_target = builder.constant_nonnative(n); + let m_target = builder.constant_nonnative(m); + + let res_target = + curve_msm_circuit(&mut builder, &p_target, &q_target, &n_target, &m_target); + builder.curve_assert_valid(&res_target); + + builder.connect_affine_point(&res_target, &res_expected); + + dbg!(builder.num_gates()); + let data = builder.build::(); + let proof = data.prove(pw).unwrap(); + + data.verify(proof) + } +} diff --git a/ecdsa/src/gadgets/curve_windowed_mul.rs b/ecdsa/src/gadgets/curve_windowed_mul.rs new file mode 100644 index 00000000..05c2c58a --- /dev/null +++ b/ecdsa/src/gadgets/curve_windowed_mul.rs @@ -0,0 +1,256 @@ +use std::marker::PhantomData; + +use num::BigUint; +use plonky2::gadgets::arithmetic_u32::U32Target; +use plonky2::hash::hash_types::RichField; +use plonky2::hash::keccak::KeccakHash; +use plonky2::iop::target::{BoolTarget, Target}; +use plonky2::plonk::circuit_builder::CircuitBuilder; +use plonky2::plonk::config::{GenericHashOut, Hasher}; +use plonky2_field::extension_field::Extendable; +use plonky2_field::field_types::Field; + +use crate::curve::curve_types::{Curve, CurveScalar}; +use crate::gadgets::biguint::BigUintTarget; +use crate::gadgets::curve::{AffinePointTarget, CircuitBuilderCurve}; +use crate::gadgets::nonnative::{CircuitBuilderNonNative, NonNativeTarget}; +use crate::gadgets::split_nonnative::CircuitBuilderSplit; + +const WINDOW_SIZE: usize = 4; + +pub trait CircuitBuilderWindowedMul, const D: usize> { + fn precompute_window( + &mut self, + p: &AffinePointTarget, + ) -> Vec>; + + fn random_access_curve_points( + &mut self, + access_index: Target, + v: Vec>, + ) -> AffinePointTarget; + + fn if_affine_point( + &mut self, + b: BoolTarget, + p1: &AffinePointTarget, + p2: &AffinePointTarget, + ) -> AffinePointTarget; + + fn curve_scalar_mul_windowed( + &mut self, + p: &AffinePointTarget, + n: &NonNativeTarget, + ) -> AffinePointTarget; +} + +impl, const D: usize> CircuitBuilderWindowedMul + for CircuitBuilder +{ + fn precompute_window( + &mut self, + p: &AffinePointTarget, + ) -> Vec> { + let g = (CurveScalar(C::ScalarField::rand()) * C::GENERATOR_PROJECTIVE).to_affine(); + let neg = { + let mut neg = g; + neg.y = -neg.y; + self.constant_affine_point(neg) + }; + + let mut multiples = vec![self.constant_affine_point(g)]; + for i in 1..1 << WINDOW_SIZE { + multiples.push(self.curve_add(p, &multiples[i - 1])); + } + for i in 1..1 << WINDOW_SIZE { + multiples[i] = self.curve_add(&neg, &multiples[i]); + } + multiples + } + + fn random_access_curve_points( + &mut self, + access_index: Target, + v: Vec>, + ) -> AffinePointTarget { + let num_limbs = C::BaseField::BITS / 32; + let zero = self.zero_u32(); + let x_limbs: Vec> = (0..num_limbs) + .map(|i| { + v.iter() + .map(|p| p.x.value.limbs.get(i).unwrap_or(&zero).0) + .collect() + }) + .collect(); + let y_limbs: Vec> = (0..num_limbs) + .map(|i| { + v.iter() + .map(|p| p.y.value.limbs.get(i).unwrap_or(&zero).0) + .collect() + }) + .collect(); + + let selected_x_limbs: Vec<_> = x_limbs + .iter() + .map(|limbs| U32Target(self.random_access(access_index, limbs.clone()))) + .collect(); + let selected_y_limbs: Vec<_> = y_limbs + .iter() + .map(|limbs| U32Target(self.random_access(access_index, limbs.clone()))) + .collect(); + + let x = NonNativeTarget { + value: BigUintTarget { + limbs: selected_x_limbs, + }, + _phantom: PhantomData, + }; + let y = NonNativeTarget { + value: BigUintTarget { + limbs: selected_y_limbs, + }, + _phantom: PhantomData, + }; + AffinePointTarget { x, y } + } + + fn if_affine_point( + &mut self, + b: BoolTarget, + p1: &AffinePointTarget, + p2: &AffinePointTarget, + ) -> AffinePointTarget { + let new_x = self.if_nonnative(b, &p1.x, &p2.x); + let new_y = self.if_nonnative(b, &p1.y, &p2.y); + AffinePointTarget { x: new_x, y: new_y } + } + + fn curve_scalar_mul_windowed( + &mut self, + p: &AffinePointTarget, + n: &NonNativeTarget, + ) -> AffinePointTarget { + let hash_0 = KeccakHash::<25>::hash_no_pad(&[F::ZERO]); + let hash_0_scalar = C::ScalarField::from_biguint(BigUint::from_bytes_le( + &GenericHashOut::::to_bytes(&hash_0), + )); + let starting_point = CurveScalar(hash_0_scalar) * C::GENERATOR_PROJECTIVE; + let starting_point_multiplied = { + let mut cur = starting_point; + for _ in 0..C::ScalarField::BITS { + cur = cur.double(); + } + cur + }; + + let mut result = self.constant_affine_point(starting_point.to_affine()); + + let precomputation = self.precompute_window(p); + let zero = self.zero(); + + let windows = self.split_nonnative_to_4_bit_limbs(n); + for i in (0..windows.len()).rev() { + result = self.curve_repeated_double(&result, WINDOW_SIZE); + let window = windows[i]; + + let to_add = self.random_access_curve_points(window, precomputation.clone()); + let is_zero = self.is_equal(window, zero); + let should_add = self.not(is_zero); + result = self.curve_conditional_add(&result, &to_add, should_add); + } + + let to_subtract = self.constant_affine_point(starting_point_multiplied.to_affine()); + let to_add = self.curve_neg(&to_subtract); + result = self.curve_add(&result, &to_add); + + result + } +} + +#[cfg(test)] +mod tests { + use std::ops::Neg; + + use anyhow::Result; + use plonky2::iop::witness::PartialWitness; + use plonky2::plonk::circuit_builder::CircuitBuilder; + use plonky2::plonk::circuit_data::CircuitConfig; + use plonky2::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; + use plonky2_field::field_types::Field; + use plonky2_field::secp256k1_scalar::Secp256K1Scalar; + use rand::Rng; + + use crate::curve::curve_types::{Curve, CurveScalar}; + use crate::curve::secp256k1::Secp256K1; + use crate::gadgets::curve::CircuitBuilderCurve; + use crate::gadgets::curve_windowed_mul::CircuitBuilderWindowedMul; + use crate::gadgets::nonnative::CircuitBuilderNonNative; + + #[test] + fn test_random_access_curve_points() -> Result<()> { + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + + let config = CircuitConfig::standard_ecc_config(); + + let pw = PartialWitness::new(); + let mut builder = CircuitBuilder::::new(config); + + let num_points = 16; + let points: Vec<_> = (0..num_points) + .map(|_| { + let g = (CurveScalar(Secp256K1Scalar::rand()) * Secp256K1::GENERATOR_PROJECTIVE) + .to_affine(); + builder.constant_affine_point(g) + }) + .collect(); + + let mut rng = rand::thread_rng(); + let access_index = rng.gen::() % num_points; + + let access_index_target = builder.constant(F::from_canonical_usize(access_index)); + let selected = builder.random_access_curve_points(access_index_target, points.clone()); + let expected = points[access_index].clone(); + builder.connect_affine_point(&selected, &expected); + + let data = builder.build::(); + let proof = data.prove(pw).unwrap(); + + data.verify(proof) + } + + #[test] + #[ignore] + fn test_curve_windowed_mul() -> Result<()> { + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + + let config = CircuitConfig::standard_ecc_config(); + + let pw = PartialWitness::new(); + let mut builder = CircuitBuilder::::new(config); + + let g = + (CurveScalar(Secp256K1Scalar::rand()) * Secp256K1::GENERATOR_PROJECTIVE).to_affine(); + let five = Secp256K1Scalar::from_canonical_usize(5); + let neg_five = five.neg(); + let neg_five_scalar = CurveScalar::(neg_five); + let neg_five_g = (neg_five_scalar * g.to_projective()).to_affine(); + let neg_five_g_expected = builder.constant_affine_point(neg_five_g); + builder.curve_assert_valid(&neg_five_g_expected); + + let g_target = builder.constant_affine_point(g); + let neg_five_target = builder.constant_nonnative(neg_five); + let neg_five_g_actual = builder.curve_scalar_mul_windowed(&g_target, &neg_five_target); + builder.curve_assert_valid(&neg_five_g_actual); + + builder.connect_affine_point(&neg_five_g_expected, &neg_five_g_actual); + + let data = builder.build::(); + let proof = data.prove(pw).unwrap(); + + data.verify(proof) + } +} diff --git a/ecdsa/src/gadgets/ecdsa.rs b/ecdsa/src/gadgets/ecdsa.rs new file mode 100644 index 00000000..afd79c61 --- /dev/null +++ b/ecdsa/src/gadgets/ecdsa.rs @@ -0,0 +1,117 @@ +use std::marker::PhantomData; + +use plonky2::hash::hash_types::RichField; +use plonky2::plonk::circuit_builder::CircuitBuilder; +use plonky2_field::extension_field::Extendable; +use plonky2_field::secp256k1_scalar::Secp256K1Scalar; + +use crate::curve::curve_types::Curve; +use crate::curve::secp256k1::Secp256K1; +use crate::gadgets::curve::{AffinePointTarget, CircuitBuilderCurve}; +use crate::gadgets::curve_fixed_base::fixed_base_curve_mul_circuit; +use crate::gadgets::glv::CircuitBuilderGlv; +use crate::gadgets::nonnative::{CircuitBuilderNonNative, NonNativeTarget}; + +#[derive(Clone, Debug)] +pub struct ECDSASecretKeyTarget(NonNativeTarget); + +#[derive(Clone, Debug)] +pub struct ECDSAPublicKeyTarget(AffinePointTarget); + +#[derive(Clone, Debug)] +pub struct ECDSASignatureTarget { + pub r: NonNativeTarget, + pub s: NonNativeTarget, +} + +pub fn verify_message_circuit, const D: usize>( + builder: &mut CircuitBuilder, + msg: NonNativeTarget, + sig: ECDSASignatureTarget, + pk: ECDSAPublicKeyTarget, +) { + let ECDSASignatureTarget { r, s } = sig; + + builder.curve_assert_valid(&pk.0); + + let c = builder.inv_nonnative(&s); + let u1 = builder.mul_nonnative(&msg, &c); + let u2 = builder.mul_nonnative(&r, &c); + + let point1 = fixed_base_curve_mul_circuit(builder, Secp256K1::GENERATOR_AFFINE, &u1); + let point2 = builder.glv_mul(&pk.0, &u2); + let point = builder.curve_add(&point1, &point2); + + let x = NonNativeTarget:: { + value: point.x.value, + _phantom: PhantomData, + }; + builder.connect_nonnative(&r, &x); +} + +#[cfg(test)] +mod tests { + use anyhow::Result; + use plonky2::iop::witness::PartialWitness; + use plonky2::plonk::circuit_builder::CircuitBuilder; + use plonky2::plonk::circuit_data::CircuitConfig; + use plonky2::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; + use plonky2_field::field_types::Field; + use plonky2_field::secp256k1_scalar::Secp256K1Scalar; + + use super::{ECDSAPublicKeyTarget, ECDSASignatureTarget}; + use crate::curve::curve_types::{Curve, CurveScalar}; + use crate::curve::ecdsa::{sign_message, ECDSAPublicKey, ECDSASecretKey, ECDSASignature}; + use crate::curve::secp256k1::Secp256K1; + use crate::gadgets::curve::CircuitBuilderCurve; + use crate::gadgets::ecdsa::verify_message_circuit; + use crate::gadgets::nonnative::CircuitBuilderNonNative; + + fn test_ecdsa_circuit_with_config(config: CircuitConfig) -> Result<()> { + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + + type Curve = Secp256K1; + + let pw = PartialWitness::new(); + let mut builder = CircuitBuilder::::new(config); + + let msg = Secp256K1Scalar::rand(); + let msg_target = builder.constant_nonnative(msg); + + let sk = ECDSASecretKey::(Secp256K1Scalar::rand()); + let pk = ECDSAPublicKey((CurveScalar(sk.0) * Curve::GENERATOR_PROJECTIVE).to_affine()); + + let pk_target = ECDSAPublicKeyTarget(builder.constant_affine_point(pk.0)); + + let sig = sign_message(msg, sk); + + let ECDSASignature { r, s } = sig; + let r_target = builder.constant_nonnative(r); + let s_target = builder.constant_nonnative(s); + let sig_target = ECDSASignatureTarget { + r: r_target, + s: s_target, + }; + + verify_message_circuit(&mut builder, msg_target, sig_target, pk_target); + + dbg!(builder.num_gates()); + let data = builder.build::(); + let proof = data.prove(pw).unwrap(); + data.verify(proof) + } + + #[test] + #[ignore] + fn test_ecdsa_circuit_narrow() -> Result<()> { + test_ecdsa_circuit_with_config(CircuitConfig::standard_ecc_config()) + } + + #[test] + #[ignore] + fn test_ecdsa_circuit_wide() -> Result<()> { + test_ecdsa_circuit_with_config(CircuitConfig::wide_ecc_config()) + } +} diff --git a/ecdsa/src/gadgets/glv.rs b/ecdsa/src/gadgets/glv.rs new file mode 100644 index 00000000..c567e5ab --- /dev/null +++ b/ecdsa/src/gadgets/glv.rs @@ -0,0 +1,180 @@ +use std::marker::PhantomData; + +use plonky2::hash::hash_types::RichField; +use plonky2::iop::generator::{GeneratedValues, SimpleGenerator}; +use plonky2::iop::target::{BoolTarget, Target}; +use plonky2::iop::witness::PartitionWitness; +use plonky2::plonk::circuit_builder::CircuitBuilder; +use plonky2_field::extension_field::Extendable; +use plonky2_field::field_types::{Field, PrimeField}; +use plonky2_field::secp256k1_base::Secp256K1Base; +use plonky2_field::secp256k1_scalar::Secp256K1Scalar; + +use crate::curve::glv::{decompose_secp256k1_scalar, GLV_BETA, GLV_S}; +use crate::curve::secp256k1::Secp256K1; +use crate::gadgets::biguint::{buffer_set_biguint_target, witness_get_biguint_target}; +use crate::gadgets::curve::{AffinePointTarget, CircuitBuilderCurve}; +use crate::gadgets::curve_msm::curve_msm_circuit; +use crate::gadgets::nonnative::{CircuitBuilderNonNative, NonNativeTarget}; + +pub trait CircuitBuilderGlv, const D: usize> { + fn secp256k1_glv_beta(&mut self) -> NonNativeTarget; + + fn decompose_secp256k1_scalar( + &mut self, + k: &NonNativeTarget, + ) -> ( + NonNativeTarget, + NonNativeTarget, + BoolTarget, + BoolTarget, + ); + + fn glv_mul( + &mut self, + p: &AffinePointTarget, + k: &NonNativeTarget, + ) -> AffinePointTarget; +} + +impl, const D: usize> CircuitBuilderGlv + for CircuitBuilder +{ + fn secp256k1_glv_beta(&mut self) -> NonNativeTarget { + self.constant_nonnative(GLV_BETA) + } + + fn decompose_secp256k1_scalar( + &mut self, + k: &NonNativeTarget, + ) -> ( + NonNativeTarget, + NonNativeTarget, + BoolTarget, + BoolTarget, + ) { + let k1 = self.add_virtual_nonnative_target_sized::(4); + let k2 = self.add_virtual_nonnative_target_sized::(4); + let k1_neg = self.add_virtual_bool_target(); + let k2_neg = self.add_virtual_bool_target(); + + self.add_simple_generator(GLVDecompositionGenerator:: { + k: k.clone(), + k1: k1.clone(), + k2: k2.clone(), + k1_neg, + k2_neg, + _phantom: PhantomData, + }); + + // Check that `k1_raw + GLV_S * k2_raw == k`. + let k1_raw = self.nonnative_conditional_neg(&k1, k1_neg); + let k2_raw = self.nonnative_conditional_neg(&k2, k2_neg); + let s = self.constant_nonnative(GLV_S); + let mut should_be_k = self.mul_nonnative(&s, &k2_raw); + should_be_k = self.add_nonnative(&should_be_k, &k1_raw); + self.connect_nonnative(&should_be_k, k); + + (k1, k2, k1_neg, k2_neg) + } + + fn glv_mul( + &mut self, + p: &AffinePointTarget, + k: &NonNativeTarget, + ) -> AffinePointTarget { + let (k1, k2, k1_neg, k2_neg) = self.decompose_secp256k1_scalar(k); + + let beta = self.secp256k1_glv_beta(); + let beta_px = self.mul_nonnative(&beta, &p.x); + let sp = AffinePointTarget:: { + x: beta_px, + y: p.y.clone(), + }; + + let p_neg = self.curve_conditional_neg(p, k1_neg); + let sp_neg = self.curve_conditional_neg(&sp, k2_neg); + curve_msm_circuit(self, &p_neg, &sp_neg, &k1, &k2) + } +} + +#[derive(Debug)] +struct GLVDecompositionGenerator, const D: usize> { + k: NonNativeTarget, + k1: NonNativeTarget, + k2: NonNativeTarget, + k1_neg: BoolTarget, + k2_neg: BoolTarget, + _phantom: PhantomData, +} + +impl, const D: usize> SimpleGenerator + for GLVDecompositionGenerator +{ + fn dependencies(&self) -> Vec { + self.k.value.limbs.iter().map(|l| l.0).collect() + } + + fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { + let k = Secp256K1Scalar::from_biguint(witness_get_biguint_target( + witness, + self.k.value.clone(), + )); + + let (k1, k2, k1_neg, k2_neg) = decompose_secp256k1_scalar(k); + + buffer_set_biguint_target(out_buffer, &self.k1.value, &k1.to_canonical_biguint()); + buffer_set_biguint_target(out_buffer, &self.k2.value, &k2.to_canonical_biguint()); + out_buffer.set_bool_target(self.k1_neg, k1_neg); + out_buffer.set_bool_target(self.k2_neg, k2_neg); + } +} + +#[cfg(test)] +mod tests { + use anyhow::Result; + use plonky2::iop::witness::PartialWitness; + use plonky2::plonk::circuit_builder::CircuitBuilder; + use plonky2::plonk::circuit_data::CircuitConfig; + use plonky2::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; + use plonky2_field::field_types::Field; + use plonky2_field::secp256k1_scalar::Secp256K1Scalar; + + use crate::curve::curve_types::{Curve, CurveScalar}; + use crate::curve::glv::glv_mul; + use crate::curve::secp256k1::Secp256K1; + use crate::gadgets::curve::CircuitBuilderCurve; + use crate::gadgets::glv::CircuitBuilderGlv; + use crate::gadgets::nonnative::CircuitBuilderNonNative; + + #[test] + #[ignore] + fn test_glv_gadget() -> Result<()> { + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + + let config = CircuitConfig::standard_ecc_config(); + + let pw = PartialWitness::new(); + let mut builder = CircuitBuilder::::new(config); + + let rando = + (CurveScalar(Secp256K1Scalar::rand()) * Secp256K1::GENERATOR_PROJECTIVE).to_affine(); + let randot = builder.constant_affine_point(rando); + + let scalar = Secp256K1Scalar::rand(); + let scalar_target = builder.constant_nonnative(scalar); + + let rando_glv_scalar = glv_mul(rando.to_projective(), scalar); + let expected = builder.constant_affine_point(rando_glv_scalar.to_affine()); + let actual = builder.glv_mul(&randot, &scalar_target); + builder.connect_affine_point(&expected, &actual); + + dbg!(builder.num_gates()); + let data = builder.build::(); + let proof = data.prove(pw).unwrap(); + + data.verify(proof) + } +} diff --git a/ecdsa/src/gadgets/mod.rs b/ecdsa/src/gadgets/mod.rs new file mode 100644 index 00000000..35b10100 --- /dev/null +++ b/ecdsa/src/gadgets/mod.rs @@ -0,0 +1,9 @@ +pub mod biguint; +pub mod curve; +pub mod curve_fixed_base; +pub mod curve_msm; +pub mod curve_windowed_mul; +pub mod ecdsa; +pub mod glv; +pub mod nonnative; +pub mod split_nonnative; diff --git a/plonky2/src/gadgets/nonnative.rs b/ecdsa/src/gadgets/nonnative.rs similarity index 72% rename from plonky2/src/gadgets/nonnative.rs rename to ecdsa/src/gadgets/nonnative.rs index 3f8d29e8..76f23f3f 100644 --- a/plonky2/src/gadgets/nonnative.rs +++ b/ecdsa/src/gadgets/nonnative.rs @@ -1,17 +1,19 @@ use std::marker::PhantomData; use num::{BigUint, Integer, One, Zero}; +use plonky2::gadgets::arithmetic_u32::U32Target; +use plonky2::hash::hash_types::RichField; +use plonky2::iop::generator::{GeneratedValues, SimpleGenerator}; +use plonky2::iop::target::{BoolTarget, Target}; +use plonky2::iop::witness::PartitionWitness; +use plonky2::plonk::circuit_builder::CircuitBuilder; use plonky2_field::field_types::PrimeField; use plonky2_field::{extension_field::Extendable, field_types::Field}; use plonky2_util::ceil_div_usize; -use crate::gadgets::arithmetic_u32::U32Target; -use crate::gadgets::biguint::BigUintTarget; -use crate::hash::hash_types::RichField; -use crate::iop::generator::{GeneratedValues, SimpleGenerator}; -use crate::iop::target::{BoolTarget, Target}; -use crate::iop::witness::{PartitionWitness, Witness}; -use crate::plonk::circuit_builder::CircuitBuilder; +use crate::gadgets::biguint::{ + buffer_set_biguint_target, witness_get_biguint_target, BigUintTarget, CircuitBuilderBiguint, +}; #[derive(Clone, Debug)] pub struct NonNativeTarget { @@ -19,33 +21,131 @@ pub struct NonNativeTarget { pub(crate) _phantom: PhantomData, } -impl, const D: usize> CircuitBuilder { +pub trait CircuitBuilderNonNative, const D: usize> { fn num_nonnative_limbs() -> usize { ceil_div_usize(FF::BITS, 32) } - pub fn biguint_to_nonnative(&mut self, x: &BigUintTarget) -> NonNativeTarget { + fn biguint_to_nonnative(&mut self, x: &BigUintTarget) -> NonNativeTarget; + + fn nonnative_to_canonical_biguint( + &mut self, + x: &NonNativeTarget, + ) -> BigUintTarget; + + fn constant_nonnative(&mut self, x: FF) -> NonNativeTarget; + + fn zero_nonnative(&mut self) -> NonNativeTarget; + + // Assert that two NonNativeTarget's, both assumed to be in reduced form, are equal. + fn connect_nonnative( + &mut self, + lhs: &NonNativeTarget, + rhs: &NonNativeTarget, + ); + + fn add_virtual_nonnative_target(&mut self) -> NonNativeTarget; + + fn add_virtual_nonnative_target_sized( + &mut self, + num_limbs: usize, + ) -> NonNativeTarget; + + fn add_nonnative( + &mut self, + a: &NonNativeTarget, + b: &NonNativeTarget, + ) -> NonNativeTarget; + + fn mul_nonnative_by_bool( + &mut self, + a: &NonNativeTarget, + b: BoolTarget, + ) -> NonNativeTarget; + + fn if_nonnative( + &mut self, + b: BoolTarget, + x: &NonNativeTarget, + y: &NonNativeTarget, + ) -> NonNativeTarget; + + fn add_many_nonnative( + &mut self, + to_add: &[NonNativeTarget], + ) -> NonNativeTarget; + + // Subtract two `NonNativeTarget`s. + fn sub_nonnative( + &mut self, + a: &NonNativeTarget, + b: &NonNativeTarget, + ) -> NonNativeTarget; + + fn mul_nonnative( + &mut self, + a: &NonNativeTarget, + b: &NonNativeTarget, + ) -> NonNativeTarget; + + fn mul_many_nonnative( + &mut self, + to_mul: &[NonNativeTarget], + ) -> NonNativeTarget; + + fn neg_nonnative(&mut self, x: &NonNativeTarget) -> NonNativeTarget; + + fn inv_nonnative(&mut self, x: &NonNativeTarget) -> NonNativeTarget; + + /// Returns `x % |FF|` as a `NonNativeTarget`. + fn reduce(&mut self, x: &BigUintTarget) -> NonNativeTarget; + + fn reduce_nonnative(&mut self, x: &NonNativeTarget) -> NonNativeTarget; + + fn bool_to_nonnative(&mut self, b: &BoolTarget) -> NonNativeTarget; + + // Split a nonnative field element to bits. + fn split_nonnative_to_bits(&mut self, x: &NonNativeTarget) -> Vec; + + fn nonnative_conditional_neg( + &mut self, + x: &NonNativeTarget, + b: BoolTarget, + ) -> NonNativeTarget; +} + +impl, const D: usize> CircuitBuilderNonNative + for CircuitBuilder +{ + fn num_nonnative_limbs() -> usize { + ceil_div_usize(FF::BITS, 32) + } + + fn biguint_to_nonnative(&mut self, x: &BigUintTarget) -> NonNativeTarget { NonNativeTarget { value: x.clone(), _phantom: PhantomData, } } - pub fn nonnative_to_biguint(&mut self, x: &NonNativeTarget) -> BigUintTarget { + fn nonnative_to_canonical_biguint( + &mut self, + x: &NonNativeTarget, + ) -> BigUintTarget { x.value.clone() } - pub fn constant_nonnative(&mut self, x: FF) -> NonNativeTarget { + fn constant_nonnative(&mut self, x: FF) -> NonNativeTarget { let x_biguint = self.constant_biguint(&x.to_canonical_biguint()); self.biguint_to_nonnative(&x_biguint) } - pub fn zero_nonnative(&mut self) -> NonNativeTarget { + fn zero_nonnative(&mut self) -> NonNativeTarget { self.constant_nonnative(FF::ZERO) } // Assert that two NonNativeTarget's, both assumed to be in reduced form, are equal. - pub fn connect_nonnative( + fn connect_nonnative( &mut self, lhs: &NonNativeTarget, rhs: &NonNativeTarget, @@ -53,7 +153,7 @@ impl, const D: usize> CircuitBuilder { self.connect_biguint(&lhs.value, &rhs.value); } - pub fn add_virtual_nonnative_target(&mut self) -> NonNativeTarget { + fn add_virtual_nonnative_target(&mut self) -> NonNativeTarget { let num_limbs = Self::num_nonnative_limbs::(); let value = self.add_virtual_biguint_target(num_limbs); @@ -63,7 +163,19 @@ impl, const D: usize> CircuitBuilder { } } - pub fn add_nonnative( + fn add_virtual_nonnative_target_sized( + &mut self, + num_limbs: usize, + ) -> NonNativeTarget { + let value = self.add_virtual_biguint_target(num_limbs); + + NonNativeTarget { + value, + _phantom: PhantomData, + } + } + + fn add_nonnative( &mut self, a: &NonNativeTarget, b: &NonNativeTarget, @@ -95,7 +207,7 @@ impl, const D: usize> CircuitBuilder { sum } - pub fn mul_nonnative_by_bool( + fn mul_nonnative_by_bool( &mut self, a: &NonNativeTarget, b: BoolTarget, @@ -106,7 +218,19 @@ impl, const D: usize> CircuitBuilder { } } - pub fn add_many_nonnative( + fn if_nonnative( + &mut self, + b: BoolTarget, + x: &NonNativeTarget, + y: &NonNativeTarget, + ) -> NonNativeTarget { + let not_b = self.not(b); + let maybe_x = self.mul_nonnative_by_bool(x, b); + let maybe_y = self.mul_nonnative_by_bool(y, not_b); + self.add_nonnative(&maybe_x, &maybe_y) + } + + fn add_many_nonnative( &mut self, to_add: &[NonNativeTarget], ) -> NonNativeTarget { @@ -150,7 +274,7 @@ impl, const D: usize> CircuitBuilder { } // Subtract two `NonNativeTarget`s. - pub fn sub_nonnative( + fn sub_nonnative( &mut self, a: &NonNativeTarget, b: &NonNativeTarget, @@ -178,7 +302,7 @@ impl, const D: usize> CircuitBuilder { diff } - pub fn mul_nonnative( + fn mul_nonnative( &mut self, a: &NonNativeTarget, b: &NonNativeTarget, @@ -209,7 +333,7 @@ impl, const D: usize> CircuitBuilder { prod } - pub fn mul_many_nonnative( + fn mul_many_nonnative( &mut self, to_mul: &[NonNativeTarget], ) -> NonNativeTarget { @@ -218,26 +342,20 @@ impl, const D: usize> CircuitBuilder { } let mut accumulator = self.mul_nonnative(&to_mul[0], &to_mul[1]); - for i in 2..to_mul.len() { - accumulator = self.mul_nonnative(&accumulator, &to_mul[i]); + for t in to_mul.iter().skip(2) { + accumulator = self.mul_nonnative(&accumulator, t); } accumulator } - pub fn neg_nonnative( - &mut self, - x: &NonNativeTarget, - ) -> NonNativeTarget { + fn neg_nonnative(&mut self, x: &NonNativeTarget) -> NonNativeTarget { let zero_target = self.constant_biguint(&BigUint::zero()); let zero_ff = self.biguint_to_nonnative(&zero_target); self.sub_nonnative(&zero_ff, x) } - pub fn inv_nonnative( - &mut self, - x: &NonNativeTarget, - ) -> NonNativeTarget { + fn inv_nonnative(&mut self, x: &NonNativeTarget) -> NonNativeTarget { let num_limbs = x.value.num_limbs(); let inv_biguint = self.add_virtual_biguint_target(num_limbs); let div = self.add_virtual_biguint_target(num_limbs); @@ -275,12 +393,12 @@ impl, const D: usize> CircuitBuilder { } } - pub fn reduce_nonnative(&mut self, x: &NonNativeTarget) -> NonNativeTarget { - let x_biguint = self.nonnative_to_biguint(x); + fn reduce_nonnative(&mut self, x: &NonNativeTarget) -> NonNativeTarget { + let x_biguint = self.nonnative_to_canonical_biguint(x); self.reduce(&x_biguint) } - pub fn bool_to_nonnative(&mut self, b: &BoolTarget) -> NonNativeTarget { + fn bool_to_nonnative(&mut self, b: &BoolTarget) -> NonNativeTarget { let limbs = vec![U32Target(b.target)]; let value = BigUintTarget { limbs }; @@ -291,10 +409,7 @@ impl, const D: usize> CircuitBuilder { } // Split a nonnative field element to bits. - pub fn split_nonnative_to_bits( - &mut self, - x: &NonNativeTarget, - ) -> Vec { + fn split_nonnative_to_bits(&mut self, x: &NonNativeTarget) -> Vec { let num_limbs = x.value.num_limbs(); let mut result = Vec::with_capacity(num_limbs * 32); @@ -311,6 +426,19 @@ impl, const D: usize> CircuitBuilder { result } + + fn nonnative_conditional_neg( + &mut self, + x: &NonNativeTarget, + b: BoolTarget, + ) -> NonNativeTarget { + let not_b = self.not(b); + let neg = self.neg_nonnative(x); + let x_if_true = self.mul_nonnative_by_bool(&neg, b); + let x_if_false = self.mul_nonnative_by_bool(x, not_b); + + self.add_nonnative(&x_if_true, &x_if_false) + } } #[derive(Debug)] @@ -337,8 +465,8 @@ impl, const D: usize, FF: PrimeField> SimpleGenerat } fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { - let a = witness.get_nonnative_target(self.a.clone()); - let b = witness.get_nonnative_target(self.b.clone()); + let a = FF::from_biguint(witness_get_biguint_target(witness, self.a.value.clone())); + let b = FF::from_biguint(witness_get_biguint_target(witness, self.b.value.clone())); let a_biguint = a.to_canonical_biguint(); let b_biguint = b.to_canonical_biguint(); let sum_biguint = a_biguint + b_biguint; @@ -349,7 +477,7 @@ impl, const D: usize, FF: PrimeField> SimpleGenerat (false, sum_biguint) }; - out_buffer.set_biguint_target(self.sum.value.clone(), sum_reduced); + buffer_set_biguint_target(out_buffer, &self.sum.value, &sum_reduced); out_buffer.set_bool_target(self.overflow, overflow); } } @@ -377,7 +505,9 @@ impl, const D: usize, FF: PrimeField> SimpleGenerat let summands: Vec<_> = self .summands .iter() - .map(|summand| witness.get_nonnative_target(summand.clone())) + .map(|summand| { + FF::from_biguint(witness_get_biguint_target(witness, summand.value.clone())) + }) .collect(); let summand_biguints: Vec<_> = summands .iter() @@ -392,7 +522,7 @@ impl, const D: usize, FF: PrimeField> SimpleGenerat let (overflow_biguint, sum_reduced) = sum_biguint.div_rem(&modulus); let overflow = overflow_biguint.to_u64_digits()[0] as u32; - out_buffer.set_biguint_target(self.sum.value.clone(), sum_reduced); + buffer_set_biguint_target(out_buffer, &self.sum.value, &sum_reduced); out_buffer.set_u32_target(self.overflow, overflow); } } @@ -421,19 +551,19 @@ impl, const D: usize, FF: PrimeField> SimpleGenerat } fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { - let a = witness.get_nonnative_target(self.a.clone()); - let b = witness.get_nonnative_target(self.b.clone()); + let a = FF::from_biguint(witness_get_biguint_target(witness, self.a.value.clone())); + let b = FF::from_biguint(witness_get_biguint_target(witness, self.b.value.clone())); let a_biguint = a.to_canonical_biguint(); let b_biguint = b.to_canonical_biguint(); let modulus = FF::order(); - let (diff_biguint, overflow) = if a_biguint > b_biguint { + let (diff_biguint, overflow) = if a_biguint >= b_biguint { (a_biguint - b_biguint, false) } else { (modulus + a_biguint - b_biguint, true) }; - out_buffer.set_biguint_target(self.diff.value.clone(), diff_biguint); + buffer_set_biguint_target(out_buffer, &self.diff.value, &diff_biguint); out_buffer.set_bool_target(self.overflow, overflow); } } @@ -462,8 +592,8 @@ impl, const D: usize, FF: PrimeField> SimpleGenerat } fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { - let a = witness.get_nonnative_target(self.a.clone()); - let b = witness.get_nonnative_target(self.b.clone()); + let a = FF::from_biguint(witness_get_biguint_target(witness, self.a.value.clone())); + let b = FF::from_biguint(witness_get_biguint_target(witness, self.b.value.clone())); let a_biguint = a.to_canonical_biguint(); let b_biguint = b.to_canonical_biguint(); @@ -472,8 +602,8 @@ impl, const D: usize, FF: PrimeField> SimpleGenerat let modulus = FF::order(); let (overflow_biguint, prod_reduced) = prod_biguint.div_rem(&modulus); - out_buffer.set_biguint_target(self.prod.value.clone(), prod_reduced); - out_buffer.set_biguint_target(self.overflow.clone(), overflow_biguint); + buffer_set_biguint_target(out_buffer, &self.prod.value, &prod_reduced); + buffer_set_biguint_target(out_buffer, &self.overflow, &overflow_biguint); } } @@ -493,7 +623,7 @@ impl, const D: usize, FF: PrimeField> SimpleGenerat } fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { - let x = witness.get_nonnative_target(self.x.clone()); + let x = FF::from_biguint(witness_get_biguint_target(witness, self.x.value.clone())); let inv = x.inverse(); let x_biguint = x.to_canonical_biguint(); @@ -502,22 +632,22 @@ impl, const D: usize, FF: PrimeField> SimpleGenerat let modulus = FF::order(); let (div, _rem) = prod.div_rem(&modulus); - out_buffer.set_biguint_target(self.div.clone(), div); - out_buffer.set_biguint_target(self.inv.clone(), inv_biguint); + buffer_set_biguint_target(out_buffer, &self.div, &div); + buffer_set_biguint_target(out_buffer, &self.inv, &inv_biguint); } } #[cfg(test)] mod tests { use anyhow::Result; + use plonky2::iop::witness::PartialWitness; + use plonky2::plonk::circuit_builder::CircuitBuilder; + use plonky2::plonk::circuit_data::CircuitConfig; + use plonky2::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; use plonky2_field::field_types::{Field, PrimeField}; use plonky2_field::secp256k1_base::Secp256K1Base; - use crate::iop::witness::PartialWitness; - use crate::plonk::circuit_builder::CircuitBuilder; - use crate::plonk::circuit_data::CircuitConfig; - use crate::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; - use crate::plonk::verifier::verify; + use crate::gadgets::nonnative::CircuitBuilderNonNative; #[test] fn test_nonnative_add() -> Result<()> { @@ -543,7 +673,7 @@ mod tests { let data = builder.build::(); let proof = data.prove(pw).unwrap(); - verify(proof, &data.verifier_only, &data.common) + data.verify(proof) } #[test] @@ -583,7 +713,7 @@ mod tests { let data = builder.build::(); let proof = data.prove(pw).unwrap(); - verify(proof, &data.verifier_only, &data.common) + data.verify(proof) } #[test] @@ -613,7 +743,7 @@ mod tests { let data = builder.build::(); let proof = data.prove(pw).unwrap(); - verify(proof, &data.verifier_only, &data.common) + data.verify(proof) } #[test] @@ -639,7 +769,7 @@ mod tests { let data = builder.build::(); let proof = data.prove(pw).unwrap(); - verify(proof, &data.verifier_only, &data.common) + data.verify(proof) } #[test] @@ -663,7 +793,7 @@ mod tests { let data = builder.build::(); let proof = data.prove(pw).unwrap(); - verify(proof, &data.verifier_only, &data.common) + data.verify(proof) } #[test] @@ -687,6 +817,6 @@ mod tests { let data = builder.build::(); let proof = data.prove(pw).unwrap(); - verify(proof, &data.verifier_only, &data.common) + data.verify(proof) } } diff --git a/ecdsa/src/gadgets/split_nonnative.rs b/ecdsa/src/gadgets/split_nonnative.rs new file mode 100644 index 00000000..e79438f8 --- /dev/null +++ b/ecdsa/src/gadgets/split_nonnative.rs @@ -0,0 +1,131 @@ +use std::marker::PhantomData; + +use itertools::Itertools; +use plonky2::gadgets::arithmetic_u32::U32Target; +use plonky2::hash::hash_types::RichField; +use plonky2::iop::target::Target; +use plonky2::plonk::circuit_builder::CircuitBuilder; +use plonky2_field::extension_field::Extendable; +use plonky2_field::field_types::Field; + +use crate::gadgets::biguint::BigUintTarget; +use crate::gadgets::nonnative::NonNativeTarget; + +pub trait CircuitBuilderSplit, const D: usize> { + fn split_u32_to_4_bit_limbs(&mut self, val: U32Target) -> Vec; + + fn split_nonnative_to_4_bit_limbs( + &mut self, + val: &NonNativeTarget, + ) -> Vec; + + fn split_nonnative_to_2_bit_limbs( + &mut self, + val: &NonNativeTarget, + ) -> Vec; + + // Note: assumes its inputs are 4-bit limbs, and does not range-check. + fn recombine_nonnative_4_bit_limbs( + &mut self, + limbs: Vec, + ) -> NonNativeTarget; +} + +impl, const D: usize> CircuitBuilderSplit + for CircuitBuilder +{ + fn split_u32_to_4_bit_limbs(&mut self, val: U32Target) -> Vec { + let two_bit_limbs = self.split_le_base::<4>(val.0, 16); + let four = self.constant(F::from_canonical_usize(4)); + let combined_limbs = two_bit_limbs + .iter() + .tuples() + .map(|(&a, &b)| self.mul_add(b, four, a)) + .collect(); + + combined_limbs + } + + fn split_nonnative_to_4_bit_limbs( + &mut self, + val: &NonNativeTarget, + ) -> Vec { + val.value + .limbs + .iter() + .flat_map(|&l| self.split_u32_to_4_bit_limbs(l)) + .collect() + } + + fn split_nonnative_to_2_bit_limbs( + &mut self, + val: &NonNativeTarget, + ) -> Vec { + val.value + .limbs + .iter() + .flat_map(|&l| self.split_le_base::<4>(l.0, 16)) + .collect() + } + + // Note: assumes its inputs are 4-bit limbs, and does not range-check. + fn recombine_nonnative_4_bit_limbs( + &mut self, + limbs: Vec, + ) -> NonNativeTarget { + let base = self.constant_u32(1 << 4); + let u32_limbs = limbs + .chunks(8) + .map(|chunk| { + let mut combined_chunk = self.zero_u32(); + for i in (0..8).rev() { + let (low, _high) = self.mul_add_u32(combined_chunk, base, U32Target(chunk[i])); + combined_chunk = low; + } + combined_chunk + }) + .collect(); + + NonNativeTarget { + value: BigUintTarget { limbs: u32_limbs }, + _phantom: PhantomData, + } + } +} + +#[cfg(test)] +mod tests { + use anyhow::Result; + use plonky2::iop::witness::PartialWitness; + use plonky2::plonk::circuit_builder::CircuitBuilder; + use plonky2::plonk::circuit_data::CircuitConfig; + use plonky2::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; + use plonky2_field::field_types::Field; + use plonky2_field::secp256k1_scalar::Secp256K1Scalar; + + use crate::gadgets::nonnative::{CircuitBuilderNonNative, NonNativeTarget}; + use crate::gadgets::split_nonnative::CircuitBuilderSplit; + + #[test] + fn test_split_nonnative() -> Result<()> { + type FF = Secp256K1Scalar; + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + + let config = CircuitConfig::standard_ecc_config(); + let pw = PartialWitness::new(); + let mut builder = CircuitBuilder::::new(config); + + let x = FF::rand(); + let x_target = builder.constant_nonnative(x); + let split = builder.split_nonnative_to_4_bit_limbs(&x_target); + let combined: NonNativeTarget = + builder.recombine_nonnative_4_bit_limbs(split); + builder.connect_nonnative(&x_target, &combined); + + let data = builder.build::(); + let proof = data.prove(pw).unwrap(); + data.verify(proof) + } +} diff --git a/ecdsa/src/lib.rs b/ecdsa/src/lib.rs new file mode 100644 index 00000000..1de32647 --- /dev/null +++ b/ecdsa/src/lib.rs @@ -0,0 +1,4 @@ +#![allow(clippy::needless_range_loop)] + +pub mod curve; +pub mod gadgets; diff --git a/plonky2/src/fri/recursive_verifier.rs b/plonky2/src/fri/recursive_verifier.rs index 9b619ea8..f56e19ca 100644 --- a/plonky2/src/fri/recursive_verifier.rs +++ b/plonky2/src/fri/recursive_verifier.rs @@ -332,7 +332,8 @@ impl, const D: usize> CircuitBuilder { let x_index_within_coset = self.le_sum(x_index_within_coset_bits.iter()); // Check consistency with our old evaluation from the previous round. - self.random_access_extension(x_index_within_coset, old_eval, evals.clone()); + let new_eval = self.random_access_extension(x_index_within_coset, evals.clone()); + self.connect_extension(new_eval, old_eval); // Infer P(y) from {P(x)}_{x^arity=y}. old_eval = with_context!( diff --git a/plonky2/src/gadgets/arithmetic.rs b/plonky2/src/gadgets/arithmetic.rs index 11fc57bf..b7df3726 100644 --- a/plonky2/src/gadgets/arithmetic.rs +++ b/plonky2/src/gadgets/arithmetic.rs @@ -6,7 +6,9 @@ use plonky2_field::field_types::Field64; use crate::gates::arithmetic_base::ArithmeticGate; use crate::gates::exponentiation::ExponentiationGate; use crate::hash::hash_types::RichField; +use crate::iop::generator::{GeneratedValues, SimpleGenerator}; use crate::iop::target::{BoolTarget, Target}; +use crate::iop::witness::{PartitionWitness, Witness}; use crate::plonk::circuit_builder::CircuitBuilder; impl, const D: usize> CircuitBuilder { @@ -323,6 +325,60 @@ impl, const D: usize> CircuitBuilder { let res = self.sub(one, b.target); BoolTarget::new_unsafe(res) } + + pub fn and(&mut self, b1: BoolTarget, b2: BoolTarget) -> BoolTarget { + BoolTarget::new_unsafe(self.mul(b1.target, b2.target)) + } + + pub fn _if(&mut self, b: BoolTarget, x: Target, y: Target) -> Target { + let not_b = self.not(b); + let maybe_x = self.mul(b.target, x); + self.mul_add(not_b.target, y, maybe_x) + } + + pub fn is_equal(&mut self, x: Target, y: Target) -> BoolTarget { + let zero = self.zero(); + + let equal = self.add_virtual_bool_target(); + let not_equal = self.not(equal); + let inv = self.add_virtual_target(); + self.add_simple_generator(EqualityGenerator { x, y, equal, inv }); + + let diff = self.sub(x, y); + let not_equal_check = self.mul(equal.target, diff); + + let diff_normalized = self.mul(diff, inv); + let equal_check = self.sub(diff_normalized, not_equal.target); + + self.connect(not_equal_check, zero); + self.connect(equal_check, zero); + + equal + } +} + +#[derive(Debug)] +struct EqualityGenerator { + x: Target, + y: Target, + equal: BoolTarget, + inv: Target, +} + +impl SimpleGenerator for EqualityGenerator { + fn dependencies(&self) -> Vec { + vec![self.x, self.y] + } + + fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { + let x = witness.get_target(self.x); + let y = witness.get_target(self.y); + + let inv = if x != y { (x - y).inverse() } else { F::ZERO }; + + out_buffer.set_bool_target(self.equal, x == y); + out_buffer.set_target(self.inv, inv); + } } /// Represents a base arithmetic operation in the circuit. Used to memoize results. diff --git a/plonky2/src/gadgets/ecdsa.rs b/plonky2/src/gadgets/ecdsa.rs deleted file mode 100644 index 64f37e1f..00000000 --- a/plonky2/src/gadgets/ecdsa.rs +++ /dev/null @@ -1,104 +0,0 @@ -use std::marker::PhantomData; - -use crate::curve::curve_types::Curve; -use crate::field::extension_field::Extendable; -use crate::gadgets::curve::AffinePointTarget; -use crate::gadgets::nonnative::NonNativeTarget; -use crate::hash::hash_types::RichField; -use crate::plonk::circuit_builder::CircuitBuilder; - -#[derive(Clone, Debug)] -pub struct ECDSASecretKeyTarget(pub NonNativeTarget); - -#[derive(Clone, Debug)] -pub struct ECDSAPublicKeyTarget(pub AffinePointTarget); - -#[derive(Clone, Debug)] -pub struct ECDSASignatureTarget { - pub r: NonNativeTarget, - pub s: NonNativeTarget, -} - -impl, const D: usize> CircuitBuilder { - pub fn verify_message( - &mut self, - msg: NonNativeTarget, - sig: ECDSASignatureTarget, - pk: ECDSAPublicKeyTarget, - ) { - let ECDSASignatureTarget { r, s } = sig; - - self.curve_assert_valid(&pk.0); - - let c = self.inv_nonnative(&s); - let u1 = self.mul_nonnative(&msg, &c); - let u2 = self.mul_nonnative(&r, &c); - - let g = self.constant_affine_point(C::GENERATOR_AFFINE); - let point1 = self.curve_scalar_mul(&g, &u1); - let point2 = self.curve_scalar_mul(&pk.0, &u2); - let point = self.curve_add(&point1, &point2); - - let x = NonNativeTarget:: { - value: point.x.value, - _phantom: PhantomData, - }; - self.connect_nonnative(&r, &x); - } -} - -#[cfg(test)] -mod tests { - use anyhow::Result; - - use crate::curve::curve_types::{Curve, CurveScalar}; - use crate::curve::ecdsa::{sign_message, ECDSAPublicKey, ECDSASecretKey, ECDSASignature}; - use crate::curve::secp256k1::Secp256K1; - use crate::field::field_types::Field; - use crate::field::secp256k1_scalar::Secp256K1Scalar; - use crate::gadgets::ecdsa::{ECDSAPublicKeyTarget, ECDSASignatureTarget}; - use crate::iop::witness::PartialWitness; - use crate::plonk::circuit_builder::CircuitBuilder; - use crate::plonk::circuit_data::CircuitConfig; - use crate::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; - use crate::plonk::verifier::verify; - - #[test] - #[ignore] - fn test_ecdsa_circuit() -> Result<()> { - const D: usize = 2; - type C = PoseidonGoldilocksConfig; - type F = >::F; - - type Curve = Secp256K1; - - let config = CircuitConfig::standard_ecc_config(); - - let pw = PartialWitness::new(); - let mut builder = CircuitBuilder::::new(config); - - let msg = Secp256K1Scalar::rand(); - let msg_target = builder.constant_nonnative(msg); - - let sk = ECDSASecretKey::(Secp256K1Scalar::rand()); - let pk = ECDSAPublicKey((CurveScalar(sk.0) * Curve::GENERATOR_PROJECTIVE).to_affine()); - - let pk_target = ECDSAPublicKeyTarget(builder.constant_affine_point(pk.0)); - - let sig = sign_message(msg, sk); - - let ECDSASignature { r, s } = sig; - let r_target = builder.constant_nonnative(r); - let s_target = builder.constant_nonnative(s); - let sig_target = ECDSASignatureTarget { - r: r_target, - s: s_target, - }; - - builder.verify_message(msg_target, sig_target, pk_target); - - let data = builder.build::(); - let proof = data.prove(pw).unwrap(); - verify(proof, &data.verifier_only, &data.common) - } -} diff --git a/plonky2/src/gadgets/mod.rs b/plonky2/src/gadgets/mod.rs index ec4d1263..d8613337 100644 --- a/plonky2/src/gadgets/mod.rs +++ b/plonky2/src/gadgets/mod.rs @@ -1,13 +1,9 @@ pub mod arithmetic; pub mod arithmetic_extension; pub mod arithmetic_u32; -pub mod biguint; -pub mod curve; -pub mod ecdsa; pub mod hash; pub mod interpolation; pub mod multiple_comparison; -pub mod nonnative; pub mod polynomial; pub mod random_access; pub mod range_check; diff --git a/plonky2/src/gadgets/random_access.rs b/plonky2/src/gadgets/random_access.rs index 9518e9fa..ec3c889a 100644 --- a/plonky2/src/gadgets/random_access.rs +++ b/plonky2/src/gadgets/random_access.rs @@ -10,13 +10,15 @@ use crate::plonk::circuit_builder::CircuitBuilder; impl, const D: usize> CircuitBuilder { /// Checks that a `Target` matches a vector at a non-deterministic index. /// Note: `access_index` is not range-checked. - pub fn random_access(&mut self, access_index: Target, claimed_element: Target, v: Vec) { + pub fn random_access(&mut self, access_index: Target, v: Vec) -> Target { let vec_size = v.len(); let bits = log2_strict(vec_size); debug_assert!(vec_size > 0); if vec_size == 1 { - return self.connect(claimed_element, v[0]); + return v[0]; } + let claimed_element = self.add_virtual_target(); + let dummy_gate = RandomAccessGate::::new_from_config(&self.config, bits); let (gate_index, copy) = self.find_slot(dummy_gate, &[], &[]); @@ -34,6 +36,8 @@ impl, const D: usize> CircuitBuilder { claimed_element, Target::wire(gate_index, dummy_gate.wire_claimed_element(copy)), ); + + claimed_element } /// Checks that an `ExtensionTarget` matches a vector at a non-deterministic index. @@ -41,16 +45,13 @@ impl, const D: usize> CircuitBuilder { pub fn random_access_extension( &mut self, access_index: Target, - claimed_element: ExtensionTarget, v: Vec>, - ) { - for i in 0..D { - self.random_access( - access_index, - claimed_element.0[i], - v.iter().map(|et| et.0[i]).collect(), - ); - } + ) -> ExtensionTarget { + let v: Vec<_> = (0..D) + .map(|i| self.random_access(access_index, v.iter().map(|et| et.0[i]).collect())) + .collect(); + + ExtensionTarget(v.try_into().unwrap()) } } @@ -80,7 +81,8 @@ mod tests { for i in 0..len { let it = builder.constant(F::from_canonical_usize(i)); let elem = builder.constant_extension(vec[i]); - builder.random_access_extension(it, elem, v.clone()); + let res = builder.random_access_extension(it, v.clone()); + builder.connect_extension(elem, res); } let data = builder.build::(); diff --git a/plonky2/src/hash/merkle_proofs.rs b/plonky2/src/hash/merkle_proofs.rs index c3ebf406..c4188271 100644 --- a/plonky2/src/hash/merkle_proofs.rs +++ b/plonky2/src/hash/merkle_proofs.rs @@ -78,11 +78,9 @@ impl, const D: usize> CircuitBuilder { let index = self.le_sum(leaf_index_bits[proof.siblings.len()..].iter().copied()); for i in 0..4 { - self.random_access( - index, - state.elements[i], - merkle_cap.0.iter().map(|h| h.elements[i]).collect(), - ); + let result = + self.random_access(index, merkle_cap.0.iter().map(|h| h.elements[i]).collect()); + self.connect(result, state.elements[i]); } } @@ -110,11 +108,11 @@ impl, const D: usize> CircuitBuilder { } for i in 0..4 { - self.random_access( + let result = self.random_access( cap_index, - state.elements[i], merkle_cap.0.iter().map(|h| h.elements[i]).collect(), ); + self.connect(result, state.elements[i]); } } diff --git a/plonky2/src/iop/generator.rs b/plonky2/src/iop/generator.rs index 1569e889..f36ba3aa 100644 --- a/plonky2/src/iop/generator.rs +++ b/plonky2/src/iop/generator.rs @@ -1,13 +1,10 @@ use std::fmt::Debug; use std::marker::PhantomData; -use num::BigUint; use plonky2_field::extension_field::{Extendable, FieldExtension}; -use plonky2_field::field_types::{Field, PrimeField}; +use plonky2_field::field_types::Field; use crate::gadgets::arithmetic_u32::U32Target; -use crate::gadgets::biguint::BigUintTarget; -use crate::gadgets::nonnative::NonNativeTarget; use crate::hash::hash_types::{HashOut, HashOutTarget, RichField}; use crate::iop::ext_target::ExtensionTarget; use crate::iop::target::{BoolTarget, Target}; @@ -169,21 +166,6 @@ impl GeneratedValues { self.set_target(target.0, F::from_canonical_u32(value)) } - pub fn set_biguint_target(&mut self, target: BigUintTarget, value: BigUint) { - let mut limbs = value.to_u32_digits(); - - assert!(target.num_limbs() >= limbs.len()); - - limbs.resize(target.num_limbs(), 0); - for i in 0..target.num_limbs() { - self.set_u32_target(target.get_limb(i), limbs[i]); - } - } - - pub fn set_nonnative_target(&mut self, target: NonNativeTarget, value: FF) { - self.set_biguint_target(target.value, value.to_canonical_biguint()) - } - pub fn set_hash_target(&mut self, ht: HashOutTarget, value: HashOut) { ht.elements .iter() diff --git a/plonky2/src/iop/witness.rs b/plonky2/src/iop/witness.rs index de6d4a05..4a565d02 100644 --- a/plonky2/src/iop/witness.rs +++ b/plonky2/src/iop/witness.rs @@ -1,15 +1,12 @@ use std::collections::HashMap; use itertools::Itertools; -use num::{BigUint, FromPrimitive, Zero}; use plonky2_field::extension_field::{Extendable, FieldExtension}; -use plonky2_field::field_types::{Field, PrimeField}; +use plonky2_field::field_types::Field; use crate::fri::structure::{FriOpenings, FriOpeningsTarget}; use crate::fri::witness_util::set_fri_proof_target; use crate::gadgets::arithmetic_u32::U32Target; -use crate::gadgets::biguint::BigUintTarget; -use crate::gadgets::nonnative::NonNativeTarget; use crate::hash::hash_types::HashOutTarget; use crate::hash::hash_types::RichField; use crate::hash::hash_types::{HashOut, MerkleCapTarget}; @@ -63,30 +60,6 @@ pub trait Witness { panic!("not a bool") } - fn get_biguint_target(&self, target: BigUintTarget) -> BigUint - where - F: PrimeField, - { - let mut result = BigUint::zero(); - - let limb_base = BigUint::from_u64(1 << 32u64).unwrap(); - for i in (0..target.num_limbs()).rev() { - let limb = target.get_limb(i); - result *= &limb_base; - result += self.get_target(limb.0).to_canonical_biguint(); - } - - result - } - - fn get_nonnative_target(&self, target: NonNativeTarget) -> FF - where - F: PrimeField, - { - let val = self.get_biguint_target(target.value); - FF::from_biguint(val) - } - fn get_hash_target(&self, ht: HashOutTarget) -> HashOut { HashOut { elements: self.get_targets(&ht.elements).try_into().unwrap(), @@ -159,12 +132,6 @@ pub trait Witness { self.set_target(target.0, F::from_canonical_u32(value)) } - fn set_biguint_target(&mut self, target: &BigUintTarget, value: &BigUint) { - for (<, &l) in target.limbs.iter().zip(&value.to_u32_digits()) { - self.set_u32_target(lt, l); - } - } - /// Set the targets in a `ProofWithPublicInputsTarget` to their corresponding values in a /// `ProofWithPublicInputs`. fn set_proof_with_pis_target, const D: usize>( diff --git a/plonky2/src/lib.rs b/plonky2/src/lib.rs index e5e77bb9..1502cea9 100644 --- a/plonky2/src/lib.rs +++ b/plonky2/src/lib.rs @@ -12,7 +12,6 @@ pub use plonky2_field as field; -pub mod curve; pub mod fri; pub mod gadgets; pub mod gates; diff --git a/plonky2/src/plonk/circuit_data.rs b/plonky2/src/plonk/circuit_data.rs index 3d4ee2df..34b38fcf 100644 --- a/plonky2/src/plonk/circuit_data.rs +++ b/plonky2/src/plonk/circuit_data.rs @@ -86,6 +86,13 @@ impl CircuitConfig { } } + pub fn wide_ecc_config() -> Self { + Self { + num_wires: 234, + ..Self::standard_recursion_config() + } + } + pub fn standard_recursion_zk_config() -> Self { CircuitConfig { zero_knowledge: true,