diff --git a/insertion/src/insertion_gate.rs b/insertion/src/insertion_gate.rs index 8ee60483..442416d3 100644 --- a/insertion/src/insertion_gate.rs +++ b/insertion/src/insertion_gate.rs @@ -404,7 +404,7 @@ mod tests { v.extend(equality_dummy_vals); v.extend(insert_here_vals); - v.iter().map(|&x| x.into()).collect::>() + v.iter().map(|&x| x.into()).collect() } let orig_vec = vec![FF::rand(); 3]; diff --git a/plonky2/src/curve/curve_types.rs b/plonky2/src/curve/curve_types.rs index b7ee34e6..9599f6fe 100644 --- a/plonky2/src/curve/curve_types.rs +++ b/plonky2/src/curve/curve_types.rs @@ -259,3 +259,11 @@ impl Neg for ProjectivePoint { ProjectivePoint { x, y: -y, z } } } + +pub fn base_to_scalar(x: C::BaseField) -> C::ScalarField { + C::ScalarField::from_biguint(x.to_biguint()) +} + +pub fn scalar_to_base(x: C::ScalarField) -> C::BaseField { + C::BaseField::from_biguint(x.to_biguint()) +} diff --git a/plonky2/src/curve/ecdsa.rs b/plonky2/src/curve/ecdsa.rs new file mode 100644 index 00000000..c84c4c10 --- /dev/null +++ b/plonky2/src/curve/ecdsa.rs @@ -0,0 +1,72 @@ +use crate::curve::curve_msm::msm_parallel; +use crate::curve::curve_types::{base_to_scalar, AffinePoint, Curve, CurveScalar}; +use crate::field::field_types::Field; + +pub struct ECDSASignature { + pub r: C::ScalarField, + pub s: C::ScalarField, +} + +pub struct ECDSASecretKey(pub C::ScalarField); +pub struct ECDSAPublicKey(pub AffinePoint); + +pub fn sign_message(msg: C::ScalarField, sk: ECDSASecretKey) -> ECDSASignature { + let (k, rr) = { + let mut k = C::ScalarField::rand(); + let mut rr = (CurveScalar(k) * C::GENERATOR_PROJECTIVE).to_affine(); + while rr.x == C::BaseField::ZERO { + k = C::ScalarField::rand(); + rr = (CurveScalar(k) * C::GENERATOR_PROJECTIVE).to_affine(); + } + (k, rr) + }; + let r = base_to_scalar::(rr.x); + + let s = k.inverse() * (msg + r * sk.0); + + ECDSASignature { r, s } +} + +pub fn verify_message( + msg: C::ScalarField, + sig: ECDSASignature, + pk: ECDSAPublicKey, +) -> bool { + let ECDSASignature { r, s } = sig; + + assert!(pk.0.is_valid()); + + let c = s.inverse(); + let u1 = msg * c; + let u2 = r * c; + + let g = C::GENERATOR_PROJECTIVE; + let w = 5; // Experimentally fastest + let point_proj = msm_parallel(&[u1, u2], &[g, pk.0.to_projective()], w); + let point = point_proj.to_affine(); + + let x = base_to_scalar::(point.x); + r == x +} + +#[cfg(test)] +mod tests { + use crate::curve::curve_types::{Curve, CurveScalar}; + use crate::curve::ecdsa::{sign_message, verify_message, ECDSAPublicKey, ECDSASecretKey}; + use crate::curve::secp256k1::Secp256K1; + use crate::field::field_types::Field; + use crate::field::secp256k1_scalar::Secp256K1Scalar; + + #[test] + fn test_ecdsa_native() { + type C = Secp256K1; + + let msg = Secp256K1Scalar::rand(); + let sk = ECDSASecretKey(Secp256K1Scalar::rand()); + let pk = ECDSAPublicKey((CurveScalar(sk.0) * C::GENERATOR_PROJECTIVE).to_affine()); + + let sig = sign_message(msg, sk); + let result = verify_message(msg, sig, pk); + assert!(result); + } +} diff --git a/plonky2/src/curve/mod.rs b/plonky2/src/curve/mod.rs index d31e373e..8dd6f0d6 100644 --- a/plonky2/src/curve/mod.rs +++ b/plonky2/src/curve/mod.rs @@ -3,4 +3,5 @@ pub mod curve_msm; pub mod curve_multiplication; pub mod curve_summation; pub mod curve_types; +pub mod ecdsa; pub mod secp256k1; diff --git a/plonky2/src/gadgets/arithmetic.rs b/plonky2/src/gadgets/arithmetic.rs index e280fa5e..90f3b090 100644 --- a/plonky2/src/gadgets/arithmetic.rs +++ b/plonky2/src/gadgets/arithmetic.rs @@ -315,6 +315,12 @@ impl, const D: usize> CircuitBuilder { let x_ext = self.convert_to_ext(x); self.inverse_extension(x_ext).0[0] } + + pub fn not(&mut self, b: BoolTarget) -> BoolTarget { + let one = self.one(); + let res = self.sub(one, b.target); + BoolTarget::new_unsafe(res) + } } /// Represents a base arithmetic operation in the circuit. Used to memoize results. diff --git a/plonky2/src/gadgets/arithmetic_u32.rs b/plonky2/src/gadgets/arithmetic_u32.rs index 6116f61b..8a0d5784 100644 --- a/plonky2/src/gadgets/arithmetic_u32.rs +++ b/plonky2/src/gadgets/arithmetic_u32.rs @@ -1,9 +1,14 @@ +use std::marker::PhantomData; + use plonky2_field::extension_field::Extendable; +use crate::gates::add_many_u32::U32AddManyGate; use crate::gates::arithmetic_u32::U32ArithmeticGate; use crate::gates::subtraction_u32::U32SubtractionGate; use crate::hash::hash_types::RichField; +use crate::iop::generator::{GeneratedValues, SimpleGenerator}; use crate::iop::target::Target; +use crate::iop::witness::{PartitionWitness, Witness}; use crate::plonk::circuit_builder::CircuitBuilder; #[derive(Clone, Copy, Debug)] @@ -113,18 +118,57 @@ impl, const D: usize> CircuitBuilder { 1 => (to_add[0], self.zero_u32()), 2 => self.add_u32(to_add[0], to_add[1]), _ => { - let (mut low, mut carry) = self.add_u32(to_add[0], to_add[1]); - for i in 2..to_add.len() { - let (new_low, new_carry) = self.add_u32(to_add[i], low); - let (combined_carry, _zero) = self.add_u32(carry, new_carry); - low = new_low; - carry = combined_carry; + let num_addends = to_add.len(); + let gate = U32AddManyGate::::new_from_config(&self.config, num_addends); + let (gate_index, copy) = self.find_u32_add_many_gate(num_addends); + + for j in 0..num_addends { + self.connect( + Target::wire(gate_index, gate.wire_ith_op_jth_addend(copy, j)), + to_add[j].0, + ); } - (low, carry) + let zero = self.zero(); + self.connect(Target::wire(gate_index, gate.wire_ith_carry(copy)), zero); + + let output_low = + U32Target(Target::wire(gate_index, gate.wire_ith_output_result(copy))); + let output_high = + U32Target(Target::wire(gate_index, gate.wire_ith_output_carry(copy))); + + (output_low, output_high) } } } + pub fn add_u32s_with_carry( + &mut self, + to_add: &[U32Target], + carry: U32Target, + ) -> (U32Target, U32Target) { + if to_add.len() == 1 { + return self.add_u32(to_add[0], carry); + } + + let num_addends = to_add.len(); + + let gate = U32AddManyGate::::new_from_config(&self.config, num_addends); + let (gate_index, copy) = self.find_u32_add_many_gate(num_addends); + + for j in 0..num_addends { + self.connect( + Target::wire(gate_index, gate.wire_ith_op_jth_addend(copy, j)), + to_add[j].0, + ); + } + self.connect(Target::wire(gate_index, gate.wire_ith_carry(copy)), carry.0); + + let output = U32Target(Target::wire(gate_index, gate.wire_ith_output_result(copy))); + let output_carry = U32Target(Target::wire(gate_index, gate.wire_ith_output_carry(copy))); + + (output, output_carry) + } + pub fn mul_u32(&mut self, a: U32Target, b: U32Target) -> (U32Target, U32Target) { let zero = self.zero_u32(); self.mul_add_u32(a, b, zero) @@ -153,3 +197,75 @@ impl, const D: usize> CircuitBuilder { (output_result, output_borrow) } } + +#[derive(Debug)] +struct SplitToU32Generator, const D: usize> { + x: Target, + low: U32Target, + high: U32Target, + _phantom: PhantomData, +} + +impl, const D: usize> SimpleGenerator + for SplitToU32Generator +{ + fn dependencies(&self) -> Vec { + vec![self.x] + } + + fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { + let x = witness.get_target(self.x); + let x_u64 = x.to_canonical_u64(); + let low = x_u64 as u32; + let high = (x_u64 >> 32) as u32; + + out_buffer.set_u32_target(self.low, low); + out_buffer.set_u32_target(self.high, high); + } +} + +#[cfg(test)] +mod tests { + use anyhow::Result; + use rand::{thread_rng, Rng}; + + use crate::iop::witness::PartialWitness; + use crate::plonk::circuit_builder::CircuitBuilder; + use crate::plonk::circuit_data::CircuitConfig; + use crate::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; + use crate::plonk::verifier::verify; + + #[test] + pub fn test_add_many_u32s() -> Result<()> { + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + + const NUM_ADDENDS: usize = 15; + + let config = CircuitConfig::standard_recursion_config(); + + let pw = PartialWitness::new(); + let mut builder = CircuitBuilder::::new(config); + + let mut rng = thread_rng(); + let mut to_add = Vec::new(); + let mut sum = 0u64; + for _ in 0..NUM_ADDENDS { + let x: u32 = rng.gen(); + sum += x as u64; + to_add.push(builder.constant_u32(x)); + } + let carry = builder.zero_u32(); + let (result_low, result_high) = builder.add_u32s_with_carry(&to_add, carry); + let expected_low = builder.constant_u32((sum % (1 << 32)) as u32); + let expected_high = builder.constant_u32((sum >> 32) as u32); + + builder.connect_u32(result_low, expected_low); + builder.connect_u32(result_high, expected_high); + + let data = builder.build::(); + let proof = data.prove(pw).unwrap(); + verify(proof, &data.verifier_only, &data.common) + } +} diff --git a/plonky2/src/gadgets/biguint.rs b/plonky2/src/gadgets/biguint.rs index 77013d27..c9ad7280 100644 --- a/plonky2/src/gadgets/biguint.rs +++ b/plonky2/src/gadgets/biguint.rs @@ -1,6 +1,6 @@ use std::marker::PhantomData; -use num::{BigUint, Integer}; +use num::{BigUint, Integer, Zero}; use plonky2_field::extension_field::Extendable; use crate::gadgets::arithmetic_u32::U32Target; @@ -33,6 +33,10 @@ impl, const D: usize> CircuitBuilder { BigUintTarget { limbs } } + pub fn zero_biguint(&mut self) -> BigUintTarget { + self.constant_biguint(&BigUint::zero()) + } + pub fn connect_biguint(&mut self, lhs: &BigUintTarget, rhs: &BigUintTarget) { let min_limbs = lhs.num_limbs().min(rhs.num_limbs()); for i in 0..min_limbs { @@ -76,9 +80,7 @@ impl, const D: usize> CircuitBuilder { } pub fn add_virtual_biguint_target(&mut self, num_limbs: usize) -> BigUintTarget { - let limbs = (0..num_limbs) - .map(|_| self.add_virtual_u32_target()) - .collect(); + let limbs = self.add_virtual_u32_targets(num_limbs); BigUintTarget { limbs } } @@ -143,8 +145,7 @@ impl, const D: usize> CircuitBuilder { let mut combined_limbs = vec![]; let mut carry = self.zero_u32(); for summands in &mut to_add { - summands.push(carry); - let (new_result, new_carry) = self.add_many_u32(summands); + let (new_result, new_carry) = self.add_u32s_with_carry(summands, carry); combined_limbs.push(new_result); carry = new_carry; } @@ -155,6 +156,18 @@ impl, const D: usize> CircuitBuilder { } } + pub fn mul_biguint_by_bool(&mut self, a: &BigUintTarget, b: BoolTarget) -> BigUintTarget { + let t = b.target; + + BigUintTarget { + limbs: a + .limbs + .iter() + .map(|&l| U32Target(self.mul(l.0, t))) + .collect(), + } + } + // Returns x * y + z. This is no more efficient than mul-then-add; it's purely for convenience (only need to call one CircuitBuilder function). pub fn mul_add_biguint( &mut self, diff --git a/plonky2/src/gadgets/curve.rs b/plonky2/src/gadgets/curve.rs index 63e96721..2ff22319 100644 --- a/plonky2/src/gadgets/curve.rs +++ b/plonky2/src/gadgets/curve.rs @@ -104,29 +104,17 @@ impl, const D: usize> CircuitBuilder { let AffinePointTarget { x: x2, y: y2 } = p2; let u = self.sub_nonnative(y2, y1); - let uu = self.mul_nonnative(&u, &u); let v = self.sub_nonnative(x2, x1); - let vv = self.mul_nonnative(&v, &v); - let vvv = self.mul_nonnative(&v, &vv); - let r = self.mul_nonnative(&vv, x1); - let diff = self.sub_nonnative(&uu, &vvv); - let r2 = self.add_nonnative(&r, &r); - let a = self.sub_nonnative(&diff, &r2); - let x3 = self.mul_nonnative(&v, &a); + let v_inv = self.inv_nonnative(&v); + let s = self.mul_nonnative(&u, &v_inv); + let s_squared = self.mul_nonnative(&s, &s); + let x_sum = self.add_nonnative(x2, x1); + let x3 = self.sub_nonnative(&s_squared, &x_sum); + let x_diff = self.sub_nonnative(x1, &x3); + let prod = self.mul_nonnative(&s, &x_diff); + let y3 = self.sub_nonnative(&prod, y1); - let r_a = self.sub_nonnative(&r, &a); - let y3_first = self.mul_nonnative(&u, &r_a); - let y3_second = self.mul_nonnative(&vvv, y1); - let y3 = self.sub_nonnative(&y3_first, &y3_second); - - let z3_inv = self.inv_nonnative(&vvv); - let x3_norm = self.mul_nonnative(&x3, &z3_inv); - let y3_norm = self.mul_nonnative(&y3, &z3_inv); - - AffinePointTarget { - x: x3_norm, - y: y3_norm, - } + AffinePointTarget { x: x3, y: y3 } } pub fn curve_scalar_mul( @@ -134,11 +122,7 @@ impl, const D: usize> CircuitBuilder { p: &AffinePointTarget, n: &NonNativeTarget, ) -> AffinePointTarget { - let one = self.constant_nonnative(C::BaseField::ONE); - let bits = self.split_nonnative_to_bits(n); - let bits_as_base: Vec> = - bits.iter().map(|b| self.bool_to_nonnative(b)).collect(); let rando = (CurveScalar(C::ScalarField::rand()) * C::GENERATOR_PROJECTIVE).to_affine(); let randot = self.constant_affine_point(rando); @@ -149,15 +133,15 @@ impl, const D: usize> CircuitBuilder { let mut two_i_times_p = self.add_virtual_affine_point_target(); self.connect_affine_point(p, &two_i_times_p); - for bit in bits_as_base.iter() { - let not_bit = self.sub_nonnative(&one, bit); + for &bit in bits.iter() { + let not_bit = self.not(bit); let result_plus_2_i_p = self.curve_add(&result, &two_i_times_p); - let new_x_if_bit = self.mul_nonnative(bit, &result_plus_2_i_p.x); - let new_x_if_not_bit = self.mul_nonnative(¬_bit, &result.x); - let new_y_if_bit = self.mul_nonnative(bit, &result_plus_2_i_p.y); - let new_y_if_not_bit = self.mul_nonnative(¬_bit, &result.y); + let new_x_if_bit = self.mul_nonnative_by_bool(&result_plus_2_i_p.x, bit); + let new_x_if_not_bit = self.mul_nonnative_by_bool(&result.x, not_bit); + let new_y_if_bit = self.mul_nonnative_by_bool(&result_plus_2_i_p.y, bit); + let new_y_if_not_bit = self.mul_nonnative_by_bool(&result.y, not_bit); let new_x = self.add_nonnative(&new_x_if_bit, &new_x_if_not_bit); let new_y = self.add_nonnative(&new_y_if_bit, &new_y_if_not_bit); @@ -177,6 +161,8 @@ impl, const D: usize> CircuitBuilder { #[cfg(test)] mod tests { + use std::ops::Neg; + use anyhow::Result; use plonky2_field::field_types::Field; use plonky2_field::secp256k1_base::Secp256K1Base; @@ -196,7 +182,7 @@ mod tests { type C = PoseidonGoldilocksConfig; type F = >::F; - let config = CircuitConfig::standard_recursion_config(); + let config = CircuitConfig::standard_ecc_config(); let pw = PartialWitness::new(); let mut builder = CircuitBuilder::::new(config); @@ -221,7 +207,7 @@ mod tests { type C = PoseidonGoldilocksConfig; type F = >::F; - let config = CircuitConfig::standard_recursion_config(); + let config = CircuitConfig::standard_ecc_config(); let pw = PartialWitness::new(); let mut builder = CircuitBuilder::::new(config); @@ -248,7 +234,7 @@ mod tests { type C = PoseidonGoldilocksConfig; type F = >::F; - let config = CircuitConfig::standard_recursion_config(); + let config = CircuitConfig::standard_ecc_config(); let pw = PartialWitness::new(); let mut builder = CircuitBuilder::::new(config); @@ -285,7 +271,7 @@ mod tests { type C = PoseidonGoldilocksConfig; type F = >::F; - let config = CircuitConfig::standard_recursion_config(); + let config = CircuitConfig::standard_ecc_config(); let pw = PartialWitness::new(); let mut builder = CircuitBuilder::::new(config); @@ -310,33 +296,30 @@ mod tests { } #[test] - #[ignore] fn test_curve_mul() -> Result<()> { const D: usize = 2; type C = PoseidonGoldilocksConfig; type F = >::F; - let config = CircuitConfig { - num_routed_wires: 33, - ..CircuitConfig::standard_recursion_config() - }; + let config = CircuitConfig::standard_ecc_config(); let pw = PartialWitness::new(); let mut builder = CircuitBuilder::::new(config); let g = Secp256K1::GENERATOR_AFFINE; let five = Secp256K1Scalar::from_canonical_usize(5); - let five_scalar = CurveScalar::(five); - let five_g = (five_scalar * g.to_projective()).to_affine(); - let five_g_expected = builder.constant_affine_point(five_g); - builder.curve_assert_valid(&five_g_expected); + let neg_five = five.neg(); + let neg_five_scalar = CurveScalar::(neg_five); + let neg_five_g = (neg_five_scalar * g.to_projective()).to_affine(); + let neg_five_g_expected = builder.constant_affine_point(neg_five_g); + builder.curve_assert_valid(&neg_five_g_expected); let g_target = builder.constant_affine_point(g); - let five_target = builder.constant_nonnative(five); - let five_g_actual = builder.curve_scalar_mul(&g_target, &five_target); - builder.curve_assert_valid(&five_g_actual); + let neg_five_target = builder.constant_nonnative(neg_five); + let neg_five_g_actual = builder.curve_scalar_mul(&g_target, &neg_five_target); + builder.curve_assert_valid(&neg_five_g_actual); - builder.connect_affine_point(&five_g_expected, &five_g_actual); + builder.connect_affine_point(&neg_five_g_expected, &neg_five_g_actual); let data = builder.build::(); let proof = data.prove(pw).unwrap(); @@ -345,16 +328,12 @@ mod tests { } #[test] - #[ignore] fn test_curve_random() -> Result<()> { const D: usize = 2; type C = PoseidonGoldilocksConfig; type F = >::F; - let config = CircuitConfig { - num_routed_wires: 33, - ..CircuitConfig::standard_recursion_config() - }; + let config = CircuitConfig::standard_ecc_config(); let pw = PartialWitness::new(); let mut builder = CircuitBuilder::::new(config); diff --git a/plonky2/src/gadgets/ecdsa.rs b/plonky2/src/gadgets/ecdsa.rs new file mode 100644 index 00000000..eba04d85 --- /dev/null +++ b/plonky2/src/gadgets/ecdsa.rs @@ -0,0 +1,100 @@ +use std::marker::PhantomData; + +use crate::curve::curve_types::Curve; +use crate::field::extension_field::Extendable; +use crate::gadgets::curve::AffinePointTarget; +use crate::gadgets::nonnative::NonNativeTarget; +use crate::hash::hash_types::RichField; +use crate::plonk::circuit_builder::CircuitBuilder; + +pub struct ECDSASecretKeyTarget(NonNativeTarget); +pub struct ECDSAPublicKeyTarget(AffinePointTarget); + +pub struct ECDSASignatureTarget { + pub r: NonNativeTarget, + pub s: NonNativeTarget, +} + +impl, const D: usize> CircuitBuilder { + pub fn verify_message( + &mut self, + msg: NonNativeTarget, + sig: ECDSASignatureTarget, + pk: ECDSAPublicKeyTarget, + ) { + let ECDSASignatureTarget { r, s } = sig; + + self.curve_assert_valid(&pk.0); + + let c = self.inv_nonnative(&s); + let u1 = self.mul_nonnative(&msg, &c); + let u2 = self.mul_nonnative(&r, &c); + + let g = self.constant_affine_point(C::GENERATOR_AFFINE); + let point1 = self.curve_scalar_mul(&g, &u1); + let point2 = self.curve_scalar_mul(&pk.0, &u2); + let point = self.curve_add(&point1, &point2); + + let x = NonNativeTarget:: { + value: point.x.value, + _phantom: PhantomData, + }; + self.connect_nonnative(&r, &x); + } +} + +#[cfg(test)] +mod tests { + use anyhow::Result; + + use crate::curve::curve_types::{Curve, CurveScalar}; + use crate::curve::ecdsa::{sign_message, ECDSAPublicKey, ECDSASecretKey, ECDSASignature}; + use crate::curve::secp256k1::Secp256K1; + use crate::field::field_types::Field; + use crate::field::secp256k1_scalar::Secp256K1Scalar; + use crate::gadgets::ecdsa::{ECDSAPublicKeyTarget, ECDSASignatureTarget}; + use crate::iop::witness::PartialWitness; + use crate::plonk::circuit_builder::CircuitBuilder; + use crate::plonk::circuit_data::CircuitConfig; + use crate::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; + use crate::plonk::verifier::verify; + + #[test] + #[ignore] + fn test_ecdsa_circuit() -> Result<()> { + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + + type Curve = Secp256K1; + + let config = CircuitConfig::standard_ecc_config(); + + let pw = PartialWitness::new(); + let mut builder = CircuitBuilder::::new(config); + + let msg = Secp256K1Scalar::rand(); + let msg_target = builder.constant_nonnative(msg); + + let sk = ECDSASecretKey::(Secp256K1Scalar::rand()); + let pk = ECDSAPublicKey((CurveScalar(sk.0) * Curve::GENERATOR_PROJECTIVE).to_affine()); + + let pk_target = ECDSAPublicKeyTarget(builder.constant_affine_point(pk.0)); + + let sig = sign_message(msg, sk); + + let ECDSASignature { r, s } = sig; + let r_target = builder.constant_nonnative(r); + let s_target = builder.constant_nonnative(s); + let sig_target = ECDSASignatureTarget { + r: r_target, + s: s_target, + }; + + builder.verify_message(msg_target, sig_target, pk_target); + + let data = builder.build::(); + let proof = data.prove(pw).unwrap(); + verify(proof, &data.verifier_only, &data.common) + } +} diff --git a/plonky2/src/gadgets/mod.rs b/plonky2/src/gadgets/mod.rs index b73e2a7f..ec4d1263 100644 --- a/plonky2/src/gadgets/mod.rs +++ b/plonky2/src/gadgets/mod.rs @@ -3,6 +3,7 @@ pub mod arithmetic_extension; pub mod arithmetic_u32; pub mod biguint; pub mod curve; +pub mod ecdsa; pub mod hash; pub mod interpolation; pub mod multiple_comparison; diff --git a/plonky2/src/gadgets/multiple_comparison.rs b/plonky2/src/gadgets/multiple_comparison.rs index 88b94f3f..434fb6c1 100644 --- a/plonky2/src/gadgets/multiple_comparison.rs +++ b/plonky2/src/gadgets/multiple_comparison.rs @@ -60,8 +60,9 @@ impl, const D: usize> CircuitBuilder { /// Helper function for comparing, specifically, lists of `U32Target`s. pub fn list_le_u32(&mut self, a: Vec, b: Vec) -> BoolTarget { - let a_targets = a.iter().map(|&t| t.0).collect(); - let b_targets = b.iter().map(|&t| t.0).collect(); + let a_targets: Vec = a.iter().map(|&t| t.0).collect(); + let b_targets: Vec = b.iter().map(|&t| t.0).collect(); + self.list_le(a_targets, b_targets, 32) } } diff --git a/plonky2/src/gadgets/nonnative.rs b/plonky2/src/gadgets/nonnative.rs index 16fd022e..245b0403 100644 --- a/plonky2/src/gadgets/nonnative.rs +++ b/plonky2/src/gadgets/nonnative.rs @@ -1,6 +1,6 @@ use std::marker::PhantomData; -use num::{BigUint, Zero}; +use num::{BigUint, Integer, One, Zero}; use plonky2_field::{extension_field::Extendable, field_types::Field}; use plonky2_util::ceil_div_usize; @@ -15,7 +15,7 @@ use crate::plonk::circuit_builder::CircuitBuilder; #[derive(Clone, Debug)] pub struct NonNativeTarget { pub(crate) value: BigUintTarget, - _phantom: PhantomData, + pub(crate) _phantom: PhantomData, } impl, const D: usize> CircuitBuilder { @@ -39,6 +39,10 @@ impl, const D: usize> CircuitBuilder { self.biguint_to_nonnative(&x_biguint) } + pub fn zero_nonnative(&mut self) -> NonNativeTarget { + self.constant_nonnative(FF::ZERO) + } + // Assert that two NonNativeTarget's, both assumed to be in reduced form, are equal. pub fn connect_nonnative( &mut self, @@ -58,16 +62,90 @@ impl, const D: usize> CircuitBuilder { } } - // Add two `NonNativeTarget`s. pub fn add_nonnative( &mut self, a: &NonNativeTarget, b: &NonNativeTarget, ) -> NonNativeTarget { - let result = self.add_biguint(&a.value, &b.value); + let sum = self.add_virtual_nonnative_target::(); + let overflow = self.add_virtual_bool_target(); - // TODO: reduce add result with only one conditional subtraction - self.reduce(&result) + self.add_simple_generator(NonNativeAdditionGenerator:: { + a: a.clone(), + b: b.clone(), + sum: sum.clone(), + overflow, + _phantom: PhantomData, + }); + + let sum_expected = self.add_biguint(&a.value, &b.value); + + let modulus = self.constant_biguint(&FF::order()); + let mod_times_overflow = self.mul_biguint_by_bool(&modulus, overflow); + let sum_actual = self.add_biguint(&sum.value, &mod_times_overflow); + self.connect_biguint(&sum_expected, &sum_actual); + + // Range-check result. + // TODO: can potentially leave unreduced until necessary (e.g. when connecting values). + let cmp = self.cmp_biguint(&sum.value, &modulus); + let one = self.one(); + self.connect(cmp.target, one); + + sum + } + + pub fn mul_nonnative_by_bool( + &mut self, + a: &NonNativeTarget, + b: BoolTarget, + ) -> NonNativeTarget { + NonNativeTarget { + value: self.mul_biguint_by_bool(&a.value, b), + _phantom: PhantomData, + } + } + + pub fn add_many_nonnative( + &mut self, + to_add: &[NonNativeTarget], + ) -> NonNativeTarget { + if to_add.len() == 1 { + return to_add[0].clone(); + } + + let sum = self.add_virtual_nonnative_target::(); + let overflow = self.add_virtual_u32_target(); + let summands = to_add.to_vec(); + + self.add_simple_generator(NonNativeMultipleAddsGenerator:: { + summands: summands.clone(), + sum: sum.clone(), + overflow, + _phantom: PhantomData, + }); + + self.range_check_u32(sum.value.limbs.clone()); + self.range_check_u32(vec![overflow]); + + let sum_expected = summands + .iter() + .fold(self.zero_biguint(), |a, b| self.add_biguint(&a, &b.value)); + + let modulus = self.constant_biguint(&FF::order()); + let overflow_biguint = BigUintTarget { + limbs: vec![overflow], + }; + let mod_times_overflow = self.mul_biguint(&modulus, &overflow_biguint); + let sum_actual = self.add_biguint(&sum.value, &mod_times_overflow); + self.connect_biguint(&sum_expected, &sum_actual); + + // Range-check result. + // TODO: can potentially leave unreduced until necessary (e.g. when connecting values). + let cmp = self.cmp_biguint(&sum.value, &modulus); + let one = self.one(); + self.connect(cmp.target, one); + + sum } // Subtract two `NonNativeTarget`s. @@ -76,12 +154,27 @@ impl, const D: usize> CircuitBuilder { a: &NonNativeTarget, b: &NonNativeTarget, ) -> NonNativeTarget { - let order = self.constant_biguint(&FF::order()); - let a_plus_order = self.add_biguint(&order, &a.value); - let result = self.sub_biguint(&a_plus_order, &b.value); + let diff = self.add_virtual_nonnative_target::(); + let overflow = self.add_virtual_bool_target(); - // TODO: reduce sub result with only one conditional addition? - self.reduce(&result) + self.add_simple_generator(NonNativeSubtractionGenerator:: { + a: a.clone(), + b: b.clone(), + diff: diff.clone(), + overflow, + _phantom: PhantomData, + }); + + self.range_check_u32(diff.value.limbs.clone()); + self.assert_bool(overflow); + + let diff_plus_b = self.add_biguint(&diff.value, &b.value); + let modulus = self.constant_biguint(&FF::order()); + let mod_times_overflow = self.mul_biguint_by_bool(&modulus, overflow); + let diff_plus_b_reduced = self.sub_biguint(&diff_plus_b, &mod_times_overflow); + self.connect_biguint(&a.value, &diff_plus_b_reduced); + + diff } pub fn mul_nonnative( @@ -89,9 +182,45 @@ impl, const D: usize> CircuitBuilder { a: &NonNativeTarget, b: &NonNativeTarget, ) -> NonNativeTarget { - let result = self.mul_biguint(&a.value, &b.value); + let prod = self.add_virtual_nonnative_target::(); + let modulus = self.constant_biguint(&FF::order()); + let overflow = self.add_virtual_biguint_target( + a.value.num_limbs() + b.value.num_limbs() - modulus.num_limbs(), + ); - self.reduce(&result) + self.add_simple_generator(NonNativeMultiplicationGenerator:: { + a: a.clone(), + b: b.clone(), + prod: prod.clone(), + overflow: overflow.clone(), + _phantom: PhantomData, + }); + + self.range_check_u32(prod.value.limbs.clone()); + self.range_check_u32(overflow.limbs.clone()); + + let prod_expected = self.mul_biguint(&a.value, &b.value); + + let mod_times_overflow = self.mul_biguint(&modulus, &overflow); + let prod_actual = self.add_biguint(&prod.value, &mod_times_overflow); + self.connect_biguint(&prod_expected, &prod_actual); + + prod + } + + pub fn mul_many_nonnative( + &mut self, + to_mul: &[NonNativeTarget], + ) -> NonNativeTarget { + if to_mul.len() == 1 { + return to_mul[0].clone(); + } + + let mut accumulator = self.mul_nonnative(&to_mul[0], &to_mul[1]); + for i in 2..to_mul.len() { + accumulator = self.mul_nonnative(&accumulator, &to_mul[i]); + } + accumulator } pub fn neg_nonnative(&mut self, x: &NonNativeTarget) -> NonNativeTarget { @@ -104,36 +233,27 @@ impl, const D: usize> CircuitBuilder { pub fn inv_nonnative(&mut self, x: &NonNativeTarget) -> NonNativeTarget { let num_limbs = x.value.num_limbs(); let inv_biguint = self.add_virtual_biguint_target(num_limbs); - let inv = NonNativeTarget:: { - value: inv_biguint, - _phantom: PhantomData, - }; + let div = self.add_virtual_biguint_target(num_limbs); self.add_simple_generator(NonNativeInverseGenerator:: { x: x.clone(), - inv: inv.clone(), + inv: inv_biguint.clone(), + div: div.clone(), _phantom: PhantomData, }); - let product = self.mul_nonnative(x, &inv); - let one = self.constant_nonnative(FF::ONE); - self.connect_nonnative(&product, &one); + let product = self.mul_biguint(&x.value, &inv_biguint); - inv - } + let modulus = self.constant_biguint(&FF::order()); + let mod_times_div = self.mul_biguint(&modulus, &div); + let one = self.constant_biguint(&BigUint::one()); + let expected_product = self.add_biguint(&mod_times_div, &one); + self.connect_biguint(&product, &expected_product); - pub fn div_rem_nonnative( - &mut self, - x: &NonNativeTarget, - y: &NonNativeTarget, - ) -> (NonNativeTarget, NonNativeTarget) { - let x_biguint = self.nonnative_to_biguint(x); - let y_biguint = self.nonnative_to_biguint(y); - - let (div_biguint, rem_biguint) = self.div_rem_biguint(&x_biguint, &y_biguint); - let div = self.biguint_to_nonnative(&div_biguint); - let rem = self.biguint_to_nonnative(&rem_biguint); - (div, rem) + NonNativeTarget:: { + value: inv_biguint, + _phantom: PhantomData, + } } /// Returns `x % |FF|` as a `NonNativeTarget`. @@ -148,8 +268,7 @@ impl, const D: usize> CircuitBuilder { } } - #[allow(dead_code)] - fn reduce_nonnative(&mut self, x: &NonNativeTarget) -> NonNativeTarget { + pub fn reduce_nonnative(&mut self, x: &NonNativeTarget) -> NonNativeTarget { let x_biguint = self.nonnative_to_biguint(x); self.reduce(&x_biguint) } @@ -187,10 +306,174 @@ impl, const D: usize> CircuitBuilder { } } +#[derive(Debug)] +struct NonNativeAdditionGenerator, const D: usize, FF: Field> { + a: NonNativeTarget, + b: NonNativeTarget, + sum: NonNativeTarget, + overflow: BoolTarget, + _phantom: PhantomData, +} + +impl, const D: usize, FF: Field> SimpleGenerator + for NonNativeAdditionGenerator +{ + fn dependencies(&self) -> Vec { + self.a + .value + .limbs + .iter() + .cloned() + .chain(self.b.value.limbs.clone()) + .map(|l| l.0) + .collect() + } + + fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { + let a = witness.get_nonnative_target(self.a.clone()); + let b = witness.get_nonnative_target(self.b.clone()); + let a_biguint = a.to_biguint(); + let b_biguint = b.to_biguint(); + let sum_biguint = a_biguint + b_biguint; + let modulus = FF::order(); + let (overflow, sum_reduced) = if sum_biguint > modulus { + (true, sum_biguint - modulus) + } else { + (false, sum_biguint) + }; + + out_buffer.set_biguint_target(self.sum.value.clone(), sum_reduced); + out_buffer.set_bool_target(self.overflow, overflow); + } +} + +#[derive(Debug)] +struct NonNativeMultipleAddsGenerator, const D: usize, FF: Field> { + summands: Vec>, + sum: NonNativeTarget, + overflow: U32Target, + _phantom: PhantomData, +} + +impl, const D: usize, FF: Field> SimpleGenerator + for NonNativeMultipleAddsGenerator +{ + fn dependencies(&self) -> Vec { + self.summands + .iter() + .flat_map(|summand| summand.value.limbs.iter().map(|limb| limb.0)) + .collect() + } + + fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { + let summands: Vec<_> = self + .summands + .iter() + .map(|summand| witness.get_nonnative_target(summand.clone())) + .collect(); + let summand_biguints: Vec<_> = summands + .iter() + .map(|summand| summand.to_biguint()) + .collect(); + + let sum_biguint = summand_biguints + .iter() + .fold(BigUint::zero(), |a, b| a + b.clone()); + + let modulus = FF::order(); + let (overflow_biguint, sum_reduced) = sum_biguint.div_rem(&modulus); + let overflow = overflow_biguint.to_u64_digits()[0] as u32; + + out_buffer.set_biguint_target(self.sum.value.clone(), sum_reduced); + out_buffer.set_u32_target(self.overflow, overflow); + } +} + +#[derive(Debug)] +struct NonNativeSubtractionGenerator, const D: usize, FF: Field> { + a: NonNativeTarget, + b: NonNativeTarget, + diff: NonNativeTarget, + overflow: BoolTarget, + _phantom: PhantomData, +} + +impl, const D: usize, FF: Field> SimpleGenerator + for NonNativeSubtractionGenerator +{ + fn dependencies(&self) -> Vec { + self.a + .value + .limbs + .iter() + .cloned() + .chain(self.b.value.limbs.clone()) + .map(|l| l.0) + .collect() + } + + fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { + let a = witness.get_nonnative_target(self.a.clone()); + let b = witness.get_nonnative_target(self.b.clone()); + let a_biguint = a.to_biguint(); + let b_biguint = b.to_biguint(); + + let modulus = FF::order(); + let (diff_biguint, overflow) = if a_biguint > b_biguint { + (a_biguint - b_biguint, false) + } else { + (modulus + a_biguint - b_biguint, true) + }; + + out_buffer.set_biguint_target(self.diff.value.clone(), diff_biguint); + out_buffer.set_bool_target(self.overflow, overflow); + } +} + +#[derive(Debug)] +struct NonNativeMultiplicationGenerator, const D: usize, FF: Field> { + a: NonNativeTarget, + b: NonNativeTarget, + prod: NonNativeTarget, + overflow: BigUintTarget, + _phantom: PhantomData, +} + +impl, const D: usize, FF: Field> SimpleGenerator + for NonNativeMultiplicationGenerator +{ + fn dependencies(&self) -> Vec { + self.a + .value + .limbs + .iter() + .cloned() + .chain(self.b.value.limbs.clone()) + .map(|l| l.0) + .collect() + } + + fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { + let a = witness.get_nonnative_target(self.a.clone()); + let b = witness.get_nonnative_target(self.b.clone()); + let a_biguint = a.to_biguint(); + let b_biguint = b.to_biguint(); + + let prod_biguint = a_biguint * b_biguint; + + let modulus = FF::order(); + let (overflow_biguint, prod_reduced) = prod_biguint.div_rem(&modulus); + + out_buffer.set_biguint_target(self.prod.value.clone(), prod_reduced); + out_buffer.set_biguint_target(self.overflow.clone(), overflow_biguint); + } +} + #[derive(Debug)] struct NonNativeInverseGenerator, const D: usize, FF: Field> { x: NonNativeTarget, - inv: NonNativeTarget, + inv: BigUintTarget, + div: BigUintTarget, _phantom: PhantomData, } @@ -205,7 +488,14 @@ impl, const D: usize, FF: Field> SimpleGenerator let x = witness.get_nonnative_target(self.x.clone()); let inv = x.inverse(); - out_buffer.set_nonnative_target(self.inv.clone(), inv); + let x_biguint = x.to_biguint(); + let inv_biguint = inv.to_biguint(); + let prod = x_biguint * &inv_biguint; + let modulus = FF::order(); + let (div, _rem) = prod.div_rem(&modulus); + + out_buffer.set_biguint_target(self.div.clone(), div); + out_buffer.set_biguint_target(self.inv.clone(), inv_biguint); } } @@ -227,11 +517,12 @@ mod tests { const D: usize = 2; type C = PoseidonGoldilocksConfig; type F = >::F; + let x_ff = FF::rand(); let y_ff = FF::rand(); let sum_ff = x_ff + y_ff; - let config = CircuitConfig::standard_recursion_config(); + let config = CircuitConfig::standard_ecc_config(); let pw = PartialWitness::new(); let mut builder = CircuitBuilder::::new(config); @@ -247,12 +538,53 @@ mod tests { verify(proof, &data.verifier_only, &data.common) } + #[test] + fn test_nonnative_many_adds() -> Result<()> { + type FF = Secp256K1Base; + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + + let a_ff = FF::rand(); + let b_ff = FF::rand(); + let c_ff = FF::rand(); + let d_ff = FF::rand(); + let e_ff = FF::rand(); + let f_ff = FF::rand(); + let g_ff = FF::rand(); + let h_ff = FF::rand(); + let sum_ff = a_ff + b_ff + c_ff + d_ff + e_ff + f_ff + g_ff + h_ff; + + let config = CircuitConfig::standard_ecc_config(); + let pw = PartialWitness::new(); + let mut builder = CircuitBuilder::::new(config); + + let a = builder.constant_nonnative(a_ff); + let b = builder.constant_nonnative(b_ff); + let c = builder.constant_nonnative(c_ff); + let d = builder.constant_nonnative(d_ff); + let e = builder.constant_nonnative(e_ff); + let f = builder.constant_nonnative(f_ff); + let g = builder.constant_nonnative(g_ff); + let h = builder.constant_nonnative(h_ff); + let all = [a, b, c, d, e, f, g, h]; + let sum = builder.add_many_nonnative(&all); + + let sum_expected = builder.constant_nonnative(sum_ff); + builder.connect_nonnative(&sum, &sum_expected); + + let data = builder.build::(); + let proof = data.prove(pw).unwrap(); + verify(proof, &data.verifier_only, &data.common) + } + #[test] fn test_nonnative_sub() -> Result<()> { type FF = Secp256K1Base; const D: usize = 2; type C = PoseidonGoldilocksConfig; type F = >::F; + let x_ff = FF::rand(); let mut y_ff = FF::rand(); while y_ff.to_biguint() > x_ff.to_biguint() { @@ -260,7 +592,7 @@ mod tests { } let diff_ff = x_ff - y_ff; - let config = CircuitConfig::standard_recursion_config(); + let config = CircuitConfig::standard_ecc_config(); let pw = PartialWitness::new(); let mut builder = CircuitBuilder::::new(config); @@ -286,7 +618,7 @@ mod tests { let y_ff = FF::rand(); let product_ff = x_ff * y_ff; - let config = CircuitConfig::standard_recursion_config(); + let config = CircuitConfig::standard_ecc_config(); let pw = PartialWitness::new(); let mut builder = CircuitBuilder::::new(config); @@ -311,7 +643,7 @@ mod tests { let x_ff = FF::rand(); let neg_x_ff = -x_ff; - let config = CircuitConfig::standard_recursion_config(); + let config = CircuitConfig::standard_ecc_config(); let pw = PartialWitness::new(); let mut builder = CircuitBuilder::::new(config); @@ -335,7 +667,7 @@ mod tests { let x_ff = FF::rand(); let inv_x_ff = x_ff.inverse(); - let config = CircuitConfig::standard_recursion_config(); + let config = CircuitConfig::standard_ecc_config(); let pw = PartialWitness::new(); let mut builder = CircuitBuilder::::new(config); diff --git a/plonky2/src/gadgets/range_check.rs b/plonky2/src/gadgets/range_check.rs index f8ada106..0776fc68 100644 --- a/plonky2/src/gadgets/range_check.rs +++ b/plonky2/src/gadgets/range_check.rs @@ -1,5 +1,7 @@ use plonky2_field::extension_field::Extendable; +use crate::gadgets::arithmetic_u32::U32Target; +use crate::gates::range_check_u32::U32RangeCheckGate; use crate::hash::hash_types::RichField; use crate::iop::generator::{GeneratedValues, SimpleGenerator}; use crate::iop::target::{BoolTarget, Target}; @@ -41,6 +43,25 @@ impl, const D: usize> CircuitBuilder { (low, high) } + + pub fn range_check_u32(&mut self, vals: Vec) { + let num_input_limbs = vals.len(); + let gate = U32RangeCheckGate::::new(num_input_limbs); + let gate_index = self.add_gate(gate, vec![]); + + for i in 0..num_input_limbs { + self.connect( + Target::wire(gate_index, gate.wire_ith_input_limb(i)), + vals[i].0, + ); + } + } + + pub fn assert_bool(&mut self, b: BoolTarget) { + let z = self.mul_sub(b.target, b.target, b.target); + let zero = self.zero(); + self.connect(z, zero); + } } #[derive(Debug)] diff --git a/plonky2/src/gates/add_many_u32.rs b/plonky2/src/gates/add_many_u32.rs new file mode 100644 index 00000000..4f9c4293 --- /dev/null +++ b/plonky2/src/gates/add_many_u32.rs @@ -0,0 +1,461 @@ +use std::marker::PhantomData; + +use itertools::unfold; +use plonky2_util::ceil_div_usize; + +use crate::field::extension_field::Extendable; +use crate::field::field_types::Field; +use crate::gates::gate::Gate; +use crate::gates::util::StridedConstraintConsumer; +use crate::hash::hash_types::RichField; +use crate::iop::ext_target::ExtensionTarget; +use crate::iop::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator}; +use crate::iop::target::Target; +use crate::iop::wire::Wire; +use crate::iop::witness::{PartitionWitness, Witness}; +use crate::plonk::circuit_builder::CircuitBuilder; +use crate::plonk::circuit_data::CircuitConfig; +use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase}; + +const LOG2_MAX_NUM_ADDENDS: usize = 4; +const MAX_NUM_ADDENDS: usize = 16; + +/// A gate to perform addition on `num_addends` different 32-bit values, plus a small carry +#[derive(Copy, Clone, Debug)] +pub struct U32AddManyGate, const D: usize> { + pub num_addends: usize, + pub num_ops: usize, + _phantom: PhantomData, +} + +impl, const D: usize> U32AddManyGate { + pub fn new_from_config(config: &CircuitConfig, num_addends: usize) -> Self { + Self { + num_addends, + num_ops: Self::num_ops(num_addends, config), + _phantom: PhantomData, + } + } + + pub(crate) fn num_ops(num_addends: usize, config: &CircuitConfig) -> usize { + debug_assert!(num_addends <= MAX_NUM_ADDENDS); + let wires_per_op = (num_addends + 3) + Self::num_limbs(); + let routed_wires_per_op = num_addends + 3; + (config.num_wires / wires_per_op).min(config.num_routed_wires / routed_wires_per_op) + } + + pub fn wire_ith_op_jth_addend(&self, i: usize, j: usize) -> usize { + debug_assert!(i < self.num_ops); + debug_assert!(j < self.num_addends); + (self.num_addends + 3) * i + j + } + pub fn wire_ith_carry(&self, i: usize) -> usize { + debug_assert!(i < self.num_ops); + (self.num_addends + 3) * i + self.num_addends + } + + pub fn wire_ith_output_result(&self, i: usize) -> usize { + debug_assert!(i < self.num_ops); + (self.num_addends + 3) * i + self.num_addends + 1 + } + pub fn wire_ith_output_carry(&self, i: usize) -> usize { + debug_assert!(i < self.num_ops); + (self.num_addends + 3) * i + self.num_addends + 2 + } + + pub fn limb_bits() -> usize { + 2 + } + pub fn num_result_limbs() -> usize { + ceil_div_usize(32, Self::limb_bits()) + } + pub fn num_carry_limbs() -> usize { + ceil_div_usize(LOG2_MAX_NUM_ADDENDS, Self::limb_bits()) + } + pub fn num_limbs() -> usize { + Self::num_result_limbs() + Self::num_carry_limbs() + } + + pub fn wire_ith_output_jth_limb(&self, i: usize, j: usize) -> usize { + debug_assert!(i < self.num_ops); + debug_assert!(j < Self::num_limbs()); + (self.num_addends + 3) * self.num_ops + Self::num_limbs() * i + j + } +} + +impl, const D: usize> Gate for U32AddManyGate { + fn id(&self) -> String { + format!("{:?}", self) + } + + fn eval_unfiltered(&self, vars: EvaluationVars) -> Vec { + let mut constraints = Vec::with_capacity(self.num_constraints()); + for i in 0..self.num_ops { + let addends: Vec = (0..self.num_addends) + .map(|j| vars.local_wires[self.wire_ith_op_jth_addend(i, j)]) + .collect(); + let carry = vars.local_wires[self.wire_ith_carry(i)]; + + let computed_output = addends.iter().fold(F::Extension::ZERO, |x, &y| x + y) + carry; + + let output_result = vars.local_wires[self.wire_ith_output_result(i)]; + let output_carry = vars.local_wires[self.wire_ith_output_carry(i)]; + + let base = F::Extension::from_canonical_u64(1 << 32u64); + let combined_output = output_carry * base + output_result; + + constraints.push(combined_output - computed_output); + + let mut combined_result_limbs = F::Extension::ZERO; + let mut combined_carry_limbs = F::Extension::ZERO; + let base = F::Extension::from_canonical_u64(1u64 << Self::limb_bits()); + for j in (0..Self::num_limbs()).rev() { + let this_limb = vars.local_wires[self.wire_ith_output_jth_limb(i, j)]; + let max_limb = 1 << Self::limb_bits(); + let product = (0..max_limb) + .map(|x| this_limb - F::Extension::from_canonical_usize(x)) + .product(); + constraints.push(product); + + if j < Self::num_result_limbs() { + combined_result_limbs = base * combined_result_limbs + this_limb; + } else { + combined_carry_limbs = base * combined_carry_limbs + this_limb; + } + } + constraints.push(combined_result_limbs - output_result); + constraints.push(combined_carry_limbs - output_carry); + } + + constraints + } + + fn eval_unfiltered_base_one( + &self, + vars: EvaluationVarsBase, + mut yield_constr: StridedConstraintConsumer, + ) { + for i in 0..self.num_ops { + let addends: Vec = (0..self.num_addends) + .map(|j| vars.local_wires[self.wire_ith_op_jth_addend(i, j)]) + .collect(); + let carry = vars.local_wires[self.wire_ith_carry(i)]; + + let computed_output = addends.iter().fold(F::ZERO, |x, &y| x + y) + carry; + + let output_result = vars.local_wires[self.wire_ith_output_result(i)]; + let output_carry = vars.local_wires[self.wire_ith_output_carry(i)]; + + let base = F::from_canonical_u64(1 << 32u64); + let combined_output = output_carry * base + output_result; + + yield_constr.one(combined_output - computed_output); + + let mut combined_result_limbs = F::ZERO; + let mut combined_carry_limbs = F::ZERO; + let base = F::from_canonical_u64(1u64 << Self::limb_bits()); + for j in (0..Self::num_limbs()).rev() { + let this_limb = vars.local_wires[self.wire_ith_output_jth_limb(i, j)]; + let max_limb = 1 << Self::limb_bits(); + let product = (0..max_limb) + .map(|x| this_limb - F::from_canonical_usize(x)) + .product(); + yield_constr.one(product); + + if j < Self::num_result_limbs() { + combined_result_limbs = base * combined_result_limbs + this_limb; + } else { + combined_carry_limbs = base * combined_carry_limbs + this_limb; + } + } + yield_constr.one(combined_result_limbs - output_result); + yield_constr.one(combined_carry_limbs - output_carry); + } + } + + fn eval_unfiltered_recursively( + &self, + builder: &mut CircuitBuilder, + vars: EvaluationTargets, + ) -> Vec> { + let mut constraints = Vec::with_capacity(self.num_constraints()); + + for i in 0..self.num_ops { + let addends: Vec> = (0..self.num_addends) + .map(|j| vars.local_wires[self.wire_ith_op_jth_addend(i, j)]) + .collect(); + let carry = vars.local_wires[self.wire_ith_carry(i)]; + + let mut computed_output = carry; + for addend in addends { + computed_output = builder.add_extension(computed_output, addend); + } + + let output_result = vars.local_wires[self.wire_ith_output_result(i)]; + let output_carry = vars.local_wires[self.wire_ith_output_carry(i)]; + + let base: F::Extension = F::from_canonical_u64(1 << 32u64).into(); + let base_target = builder.constant_extension(base); + let combined_output = + builder.mul_add_extension(output_carry, base_target, output_result); + + constraints.push(builder.sub_extension(combined_output, computed_output)); + + let mut combined_result_limbs = builder.zero_extension(); + let mut combined_carry_limbs = builder.zero_extension(); + let base = builder + .constant_extension(F::Extension::from_canonical_u64(1u64 << Self::limb_bits())); + for j in (0..Self::num_limbs()).rev() { + let this_limb = vars.local_wires[self.wire_ith_output_jth_limb(i, j)]; + let max_limb = 1 << Self::limb_bits(); + + let mut product = builder.one_extension(); + for x in 0..max_limb { + let x_target = + builder.constant_extension(F::Extension::from_canonical_usize(x)); + let diff = builder.sub_extension(this_limb, x_target); + product = builder.mul_extension(product, diff); + } + constraints.push(product); + + if j < Self::num_result_limbs() { + combined_result_limbs = + builder.mul_add_extension(base, combined_result_limbs, this_limb); + } else { + combined_carry_limbs = + builder.mul_add_extension(base, combined_carry_limbs, this_limb); + } + } + constraints.push(builder.sub_extension(combined_result_limbs, output_result)); + constraints.push(builder.sub_extension(combined_carry_limbs, output_carry)); + } + + constraints + } + + fn generators( + &self, + gate_index: usize, + _local_constants: &[F], + ) -> Vec>> { + (0..self.num_ops) + .map(|i| { + let g: Box> = Box::new( + U32AddManyGenerator { + gate: *self, + gate_index, + i, + _phantom: PhantomData, + } + .adapter(), + ); + g + }) + .collect() + } + + fn num_wires(&self) -> usize { + (self.num_addends + 3) * self.num_ops + Self::num_limbs() * self.num_ops + } + + fn num_constants(&self) -> usize { + 0 + } + + fn degree(&self) -> usize { + 1 << Self::limb_bits() + } + + fn num_constraints(&self) -> usize { + self.num_ops * (3 + Self::num_limbs()) + } +} + +#[derive(Clone, Debug)] +struct U32AddManyGenerator, const D: usize> { + gate: U32AddManyGate, + gate_index: usize, + i: usize, + _phantom: PhantomData, +} + +impl, const D: usize> SimpleGenerator + for U32AddManyGenerator +{ + fn dependencies(&self) -> Vec { + let local_target = |input| Target::wire(self.gate_index, input); + + (0..self.gate.num_addends) + .map(|j| local_target(self.gate.wire_ith_op_jth_addend(self.i, j))) + .chain([local_target(self.gate.wire_ith_carry(self.i))]) + .collect() + } + + fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { + let local_wire = |input| Wire { + gate: self.gate_index, + input, + }; + + let get_local_wire = |input| witness.get_wire(local_wire(input)); + + let addends: Vec<_> = (0..self.gate.num_addends) + .map(|j| get_local_wire(self.gate.wire_ith_op_jth_addend(self.i, j))) + .collect(); + let carry = get_local_wire(self.gate.wire_ith_carry(self.i)); + + let output = addends.iter().fold(F::ZERO, |x, &y| x + y) + carry; + let output_u64 = output.to_canonical_u64(); + + let output_carry_u64 = output_u64 >> 32; + let output_result_u64 = output_u64 & ((1 << 32) - 1); + + let output_carry = F::from_canonical_u64(output_carry_u64); + let output_result = F::from_canonical_u64(output_result_u64); + + let output_carry_wire = local_wire(self.gate.wire_ith_output_carry(self.i)); + let output_result_wire = local_wire(self.gate.wire_ith_output_result(self.i)); + + out_buffer.set_wire(output_carry_wire, output_carry); + out_buffer.set_wire(output_result_wire, output_result); + + let num_result_limbs = U32AddManyGate::::num_result_limbs(); + let num_carry_limbs = U32AddManyGate::::num_carry_limbs(); + let limb_base = 1 << U32AddManyGate::::limb_bits(); + + let split_to_limbs = |mut val, num| { + unfold((), move |_| { + let ret = val % limb_base; + val /= limb_base; + Some(ret) + }) + .take(num) + .map(F::from_canonical_u64) + }; + + let result_limbs = split_to_limbs(output_result_u64, num_result_limbs); + let carry_limbs = split_to_limbs(output_carry_u64, num_carry_limbs); + + for (j, limb) in result_limbs.chain(carry_limbs).enumerate() { + let wire = local_wire(self.gate.wire_ith_output_jth_limb(self.i, j)); + out_buffer.set_wire(wire, limb); + } + } +} + +#[cfg(test)] +mod tests { + use std::marker::PhantomData; + + use anyhow::Result; + use itertools::unfold; + use rand::Rng; + + use crate::field::extension_field::quartic::QuarticExtension; + use crate::field::field_types::Field; + use crate::field::goldilocks_field::GoldilocksField; + use crate::gates::add_many_u32::U32AddManyGate; + use crate::gates::gate::Gate; + use crate::gates::gate_testing::{test_eval_fns, test_low_degree}; + use crate::hash::hash_types::HashOut; + use crate::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; + use crate::plonk::vars::EvaluationVars; + + #[test] + fn low_degree() { + test_low_degree::(U32AddManyGate:: { + num_addends: 4, + num_ops: 3, + _phantom: PhantomData, + }) + } + + #[test] + fn eval_fns() -> Result<()> { + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + test_eval_fns::(U32AddManyGate:: { + num_addends: 4, + num_ops: 3, + _phantom: PhantomData, + }) + } + + #[test] + fn test_gate_constraint() { + type F = GoldilocksField; + type FF = QuarticExtension; + const D: usize = 4; + const NUM_ADDENDS: usize = 10; + const NUM_U32_ADD_MANY_OPS: usize = 3; + + fn get_wires(addends: Vec>, carries: Vec) -> Vec { + let mut v0 = Vec::new(); + let mut v1 = Vec::new(); + + let num_result_limbs = U32AddManyGate::::num_result_limbs(); + let num_carry_limbs = U32AddManyGate::::num_carry_limbs(); + let limb_base = 1 << U32AddManyGate::::limb_bits(); + for op in 0..NUM_U32_ADD_MANY_OPS { + let adds = &addends[op]; + let ca = carries[op]; + + let output = adds.iter().sum::() + ca; + let output_result = output & ((1 << 32) - 1); + let output_carry = output >> 32; + + let split_to_limbs = |mut val, num| { + unfold((), move |_| { + let ret = val % limb_base; + val /= limb_base; + Some(ret) + }) + .take(num) + .map(F::from_canonical_u64) + }; + + let mut result_limbs: Vec<_> = + split_to_limbs(output_result, num_result_limbs).collect(); + let mut carry_limbs: Vec<_> = + split_to_limbs(output_carry, num_carry_limbs).collect(); + + for a in adds { + v0.push(F::from_canonical_u64(*a)); + } + v0.push(F::from_canonical_u64(ca)); + v0.push(F::from_canonical_u64(output_result)); + v0.push(F::from_canonical_u64(output_carry)); + v1.append(&mut result_limbs); + v1.append(&mut carry_limbs); + } + + v0.iter().chain(v1.iter()).map(|&x| x.into()).collect() + } + + let mut rng = rand::thread_rng(); + let addends: Vec> = (0..NUM_U32_ADD_MANY_OPS) + .map(|_| (0..NUM_ADDENDS).map(|_| rng.gen::() as u64).collect()) + .collect(); + let carries: Vec<_> = (0..NUM_U32_ADD_MANY_OPS) + .map(|_| rng.gen::() as u64) + .collect(); + + let gate = U32AddManyGate:: { + num_addends: NUM_ADDENDS, + num_ops: NUM_U32_ADD_MANY_OPS, + _phantom: PhantomData, + }; + + let vars = EvaluationVars { + local_constants: &[], + local_wires: &get_wires(addends, carries), + public_inputs_hash: &HashOut::rand(), + }; + + assert!( + gate.eval_unfiltered(vars).iter().all(|x| x.is_zero()), + "Gate constraints are not satisfied." + ); + } +} diff --git a/plonky2/src/gates/arithmetic_base.rs b/plonky2/src/gates/arithmetic_base.rs index 8f67dab2..738b8ad4 100644 --- a/plonky2/src/gates/arithmetic_base.rs +++ b/plonky2/src/gates/arithmetic_base.rs @@ -131,7 +131,7 @@ impl, const D: usize> Gate for ArithmeticGate ); g }) - .collect::>() + .collect() } fn num_wires(&self) -> usize { diff --git a/plonky2/src/gates/arithmetic_extension.rs b/plonky2/src/gates/arithmetic_extension.rs index 93f7277c..def09473 100644 --- a/plonky2/src/gates/arithmetic_extension.rs +++ b/plonky2/src/gates/arithmetic_extension.rs @@ -138,7 +138,7 @@ impl, const D: usize> Gate for ArithmeticExte ); g }) - .collect::>() + .collect() } fn num_wires(&self) -> usize { diff --git a/plonky2/src/gates/arithmetic_u32.rs b/plonky2/src/gates/arithmetic_u32.rs index 1d4a834c..dc03e296 100644 --- a/plonky2/src/gates/arithmetic_u32.rs +++ b/plonky2/src/gates/arithmetic_u32.rs @@ -212,7 +212,7 @@ impl, const D: usize> Gate for U32ArithmeticG ); g }) - .collect::>() + .collect() } fn num_wires(&self) -> usize { @@ -425,10 +425,7 @@ mod tests { v1.append(&mut output_limbs_f); } - v0.iter() - .chain(v1.iter()) - .map(|&x| x.into()) - .collect::>() + v0.iter().chain(v1.iter()).map(|&x| x.into()).collect() } let mut rng = rand::thread_rng(); diff --git a/plonky2/src/gates/assert_le.rs b/plonky2/src/gates/assert_le.rs index 6a99acd9..c385bb31 100644 --- a/plonky2/src/gates/assert_le.rs +++ b/plonky2/src/gates/assert_le.rs @@ -578,7 +578,7 @@ mod tests { v.append(&mut chunks_equal); v.append(&mut intermediate_values); - v.iter().map(|&x| x.into()).collect::>() + v.iter().map(|&x| x.into()).collect() }; let mut rng = rand::thread_rng(); diff --git a/plonky2/src/gates/comparison.rs b/plonky2/src/gates/comparison.rs index 9a119b1d..424ecb5b 100644 --- a/plonky2/src/gates/comparison.rs +++ b/plonky2/src/gates/comparison.rs @@ -658,7 +658,7 @@ mod tests { v.append(&mut intermediate_values); v.append(&mut msd_bits); - v.iter().map(|&x| x.into()).collect::>() + v.iter().map(|&x| x.into()).collect() }; let mut rng = rand::thread_rng(); diff --git a/plonky2/src/gates/gate.rs b/plonky2/src/gates/gate.rs index 9ae2a271..9d21c273 100644 --- a/plonky2/src/gates/gate.rs +++ b/plonky2/src/gates/gate.rs @@ -113,7 +113,7 @@ pub trait Gate, const D: usize>: 'static + Send + S builder: &mut CircuitBuilder, mut vars: EvaluationTargets, prefix: &[bool], - combined_gate_constraints: &mut Vec>, + combined_gate_constraints: &mut [ExtensionTarget], ) { let filter = compute_filter_recursively(builder, prefix, vars.local_constants); vars.remove_prefix(prefix); diff --git a/plonky2/src/gates/interpolation.rs b/plonky2/src/gates/interpolation.rs index 9d3d6e3f..46c42113 100644 --- a/plonky2/src/gates/interpolation.rs +++ b/plonky2/src/gates/interpolation.rs @@ -343,7 +343,7 @@ mod tests { for i in 0..coeffs.len() { v.extend(coeffs.coeffs[i].0); } - v.iter().map(|&x| x.into()).collect::>() + v.iter().map(|&x| x.into()).collect() } // Get a working row for InterpolationGate. diff --git a/plonky2/src/gates/low_degree_interpolation.rs b/plonky2/src/gates/low_degree_interpolation.rs index b7307470..845da5ab 100644 --- a/plonky2/src/gates/low_degree_interpolation.rs +++ b/plonky2/src/gates/low_degree_interpolation.rs @@ -443,7 +443,7 @@ mod tests { .take(gate.num_points() - 2) .flat_map(|ff| ff.0), ); - v.iter().map(|&x| x.into()).collect::>() + v.iter().map(|&x| x.into()).collect() } // Get a working row for LowDegreeInterpolationGate. diff --git a/plonky2/src/gates/mod.rs b/plonky2/src/gates/mod.rs index a3f92615..18e3e99b 100644 --- a/plonky2/src/gates/mod.rs +++ b/plonky2/src/gates/mod.rs @@ -1,6 +1,7 @@ // Gates have `new` methods that return `GateRef`s. #![allow(clippy::new_ret_no_self)] +pub mod add_many_u32; pub mod arithmetic_base; pub mod arithmetic_extension; pub mod arithmetic_u32; @@ -20,6 +21,7 @@ pub mod poseidon; pub(crate) mod poseidon_mds; pub(crate) mod public_input; pub mod random_access; +pub mod range_check_u32; pub mod reducing; pub mod reducing_extension; pub mod subtraction_u32; diff --git a/plonky2/src/gates/multiplication_extension.rs b/plonky2/src/gates/multiplication_extension.rs index 9ccfe637..54629a47 100644 --- a/plonky2/src/gates/multiplication_extension.rs +++ b/plonky2/src/gates/multiplication_extension.rs @@ -125,7 +125,7 @@ impl, const D: usize> Gate for MulExtensionGa ); g }) - .collect::>() + .collect() } fn num_wires(&self) -> usize { diff --git a/plonky2/src/gates/random_access.rs b/plonky2/src/gates/random_access.rs index 77359a19..6379f99f 100644 --- a/plonky2/src/gates/random_access.rs +++ b/plonky2/src/gates/random_access.rs @@ -209,7 +209,7 @@ impl, const D: usize> Gate for RandomAccessGa ); g }) - .collect::>() + .collect() } fn num_wires(&self) -> usize { diff --git a/plonky2/src/gates/range_check_u32.rs b/plonky2/src/gates/range_check_u32.rs new file mode 100644 index 00000000..79e91de8 --- /dev/null +++ b/plonky2/src/gates/range_check_u32.rs @@ -0,0 +1,322 @@ +use std::marker::PhantomData; + +use plonky2_util::ceil_div_usize; + +use crate::field::extension_field::Extendable; +use crate::field::field_types::Field; +use crate::gates::gate::Gate; +use crate::gates::util::StridedConstraintConsumer; +use crate::hash::hash_types::RichField; +use crate::iop::ext_target::ExtensionTarget; +use crate::iop::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator}; +use crate::iop::target::Target; +use crate::iop::witness::{PartitionWitness, Witness}; +use crate::plonk::circuit_builder::CircuitBuilder; +use crate::plonk::plonk_common::{reduce_with_powers, reduce_with_powers_ext_recursive}; +use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase}; + +/// A gate which can decompose a number into base B little-endian limbs. +#[derive(Copy, Clone, Debug)] +pub struct U32RangeCheckGate, const D: usize> { + pub num_input_limbs: usize, + _phantom: PhantomData, +} + +impl, const D: usize> U32RangeCheckGate { + pub fn new(num_input_limbs: usize) -> Self { + Self { + num_input_limbs, + _phantom: PhantomData, + } + } + + pub const AUX_LIMB_BITS: usize = 2; + pub const BASE: usize = 1 << Self::AUX_LIMB_BITS; + + fn aux_limbs_per_input_limb(&self) -> usize { + ceil_div_usize(32, Self::AUX_LIMB_BITS) + } + pub fn wire_ith_input_limb(&self, i: usize) -> usize { + debug_assert!(i < self.num_input_limbs); + i + } + pub fn wire_ith_input_limb_jth_aux_limb(&self, i: usize, j: usize) -> usize { + debug_assert!(i < self.num_input_limbs); + debug_assert!(j < self.aux_limbs_per_input_limb()); + self.num_input_limbs + self.aux_limbs_per_input_limb() * i + j + } +} + +impl, const D: usize> Gate for U32RangeCheckGate { + fn id(&self) -> String { + format!("{:?}", self) + } + + fn eval_unfiltered(&self, vars: EvaluationVars) -> Vec { + let mut constraints = Vec::with_capacity(self.num_constraints()); + + let base = F::Extension::from_canonical_usize(Self::BASE); + for i in 0..self.num_input_limbs { + let input_limb = vars.local_wires[self.wire_ith_input_limb(i)]; + let aux_limbs: Vec<_> = (0..self.aux_limbs_per_input_limb()) + .map(|j| vars.local_wires[self.wire_ith_input_limb_jth_aux_limb(i, j)]) + .collect(); + let computed_sum = reduce_with_powers(&aux_limbs, base); + + constraints.push(computed_sum - input_limb); + for aux_limb in aux_limbs { + constraints.push( + (0..Self::BASE) + .map(|i| aux_limb - F::Extension::from_canonical_usize(i)) + .product(), + ); + } + } + + constraints + } + + fn eval_unfiltered_base_one( + &self, + vars: EvaluationVarsBase, + mut yield_constr: StridedConstraintConsumer, + ) { + let base = F::from_canonical_usize(Self::BASE); + for i in 0..self.num_input_limbs { + let input_limb = vars.local_wires[self.wire_ith_input_limb(i)]; + let aux_limbs: Vec<_> = (0..self.aux_limbs_per_input_limb()) + .map(|j| vars.local_wires[self.wire_ith_input_limb_jth_aux_limb(i, j)]) + .collect(); + let computed_sum = reduce_with_powers(&aux_limbs, base); + + yield_constr.one(computed_sum - input_limb); + for aux_limb in aux_limbs { + yield_constr.one( + (0..Self::BASE) + .map(|i| aux_limb - F::from_canonical_usize(i)) + .product(), + ); + } + } + } + + fn eval_unfiltered_recursively( + &self, + builder: &mut CircuitBuilder, + vars: EvaluationTargets, + ) -> Vec> { + let mut constraints = Vec::with_capacity(self.num_constraints()); + + let base = builder.constant(F::from_canonical_usize(Self::BASE)); + for i in 0..self.num_input_limbs { + let input_limb = vars.local_wires[self.wire_ith_input_limb(i)]; + let aux_limbs: Vec<_> = (0..self.aux_limbs_per_input_limb()) + .map(|j| vars.local_wires[self.wire_ith_input_limb_jth_aux_limb(i, j)]) + .collect(); + let computed_sum = reduce_with_powers_ext_recursive(builder, &aux_limbs, base); + + constraints.push(builder.sub_extension(computed_sum, input_limb)); + for aux_limb in aux_limbs { + constraints.push({ + let mut acc = builder.one_extension(); + (0..Self::BASE).for_each(|i| { + // We update our accumulator as: + // acc' = acc (x - i) + // = acc x + (-i) acc + // Since -i is constant, we can do this in one arithmetic_extension call. + let neg_i = -F::from_canonical_usize(i); + acc = builder.arithmetic_extension(F::ONE, neg_i, acc, aux_limb, acc) + }); + acc + }); + } + } + + constraints + } + + fn generators( + &self, + gate_index: usize, + _local_constants: &[F], + ) -> Vec>> { + let gen = U32RangeCheckGenerator { + gate: *self, + gate_index, + }; + vec![Box::new(gen.adapter())] + } + + fn num_wires(&self) -> usize { + self.num_input_limbs * (1 + self.aux_limbs_per_input_limb()) + } + + fn num_constants(&self) -> usize { + 0 + } + + // Bounded by the range-check (x-0)*(x-1)*...*(x-BASE+1). + fn degree(&self) -> usize { + Self::BASE + } + + // 1 for checking the each sum of aux limbs, plus a range check for each aux limb. + fn num_constraints(&self) -> usize { + self.num_input_limbs * (1 + self.aux_limbs_per_input_limb()) + } +} + +#[derive(Debug)] +pub struct U32RangeCheckGenerator, const D: usize> { + gate: U32RangeCheckGate, + gate_index: usize, +} + +impl, const D: usize> SimpleGenerator + for U32RangeCheckGenerator +{ + fn dependencies(&self) -> Vec { + let num_input_limbs = self.gate.num_input_limbs; + (0..num_input_limbs) + .map(|i| Target::wire(self.gate_index, self.gate.wire_ith_input_limb(i))) + .collect() + } + + fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { + let num_input_limbs = self.gate.num_input_limbs; + for i in 0..num_input_limbs { + let sum_value = witness + .get_target(Target::wire( + self.gate_index, + self.gate.wire_ith_input_limb(i), + )) + .to_canonical_u64() as u32; + + let base = U32RangeCheckGate::::BASE as u32; + let limbs = (0..self.gate.aux_limbs_per_input_limb()).map(|j| { + Target::wire( + self.gate_index, + self.gate.wire_ith_input_limb_jth_aux_limb(i, j), + ) + }); + let limbs_value = (0..self.gate.aux_limbs_per_input_limb()) + .scan(sum_value, |acc, _| { + let tmp = *acc % base; + *acc /= base; + Some(F::from_canonical_u32(tmp)) + }) + .collect::>(); + + for (b, b_value) in limbs.zip(limbs_value) { + out_buffer.set_target(b, b_value); + } + } + } +} + +#[cfg(test)] +mod tests { + use std::marker::PhantomData; + + use anyhow::Result; + use itertools::unfold; + use plonky2_util::ceil_div_usize; + use rand::Rng; + + use crate::field::extension_field::quartic::QuarticExtension; + use crate::field::field_types::Field; + use crate::field::goldilocks_field::GoldilocksField; + use crate::gates::gate::Gate; + use crate::gates::gate_testing::{test_eval_fns, test_low_degree}; + use crate::gates::range_check_u32::U32RangeCheckGate; + use crate::hash::hash_types::HashOut; + use crate::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; + use crate::plonk::vars::EvaluationVars; + + #[test] + fn low_degree() { + test_low_degree::(U32RangeCheckGate::new(8)) + } + + #[test] + fn eval_fns() -> Result<()> { + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + test_eval_fns::(U32RangeCheckGate::new(8)) + } + + fn test_gate_constraint(input_limbs: Vec) { + type F = GoldilocksField; + type FF = QuarticExtension; + const D: usize = 4; + const AUX_LIMB_BITS: usize = 2; + const BASE: usize = 1 << AUX_LIMB_BITS; + const AUX_LIMBS_PER_INPUT_LIMB: usize = ceil_div_usize(32, AUX_LIMB_BITS); + + fn get_wires(input_limbs: Vec) -> Vec { + let num_input_limbs = input_limbs.len(); + let mut v = Vec::new(); + + for i in 0..num_input_limbs { + let input_limb = input_limbs[i]; + + let split_to_limbs = |mut val, num| { + unfold((), move |_| { + let ret = val % (BASE as u64); + val /= BASE as u64; + Some(ret) + }) + .take(num) + .map(F::from_canonical_u64) + }; + + let mut aux_limbs: Vec<_> = + split_to_limbs(input_limb, AUX_LIMBS_PER_INPUT_LIMB).collect(); + + v.append(&mut aux_limbs); + } + + input_limbs + .iter() + .cloned() + .map(F::from_canonical_u64) + .chain(v.iter().cloned()) + .map(|x| x.into()) + .collect() + } + + let gate = U32RangeCheckGate:: { + num_input_limbs: 8, + _phantom: PhantomData, + }; + + let vars = EvaluationVars { + local_constants: &[], + local_wires: &get_wires(input_limbs), + public_inputs_hash: &HashOut::rand(), + }; + + assert!( + gate.eval_unfiltered(vars).iter().all(|x| x.is_zero()), + "Gate constraints are not satisfied." + ); + } + + #[test] + fn test_gate_constraint_good() { + let mut rng = rand::thread_rng(); + let input_limbs: Vec<_> = (0..8).map(|_| rng.gen::() as u64).collect(); + + test_gate_constraint(input_limbs); + } + + #[test] + #[should_panic] + fn test_gate_constraint_bad() { + let mut rng = rand::thread_rng(); + let input_limbs: Vec<_> = (0..8).map(|_| rng.gen()).collect(); + + test_gate_constraint(input_limbs); + } +} diff --git a/plonky2/src/gates/subtraction_u32.rs b/plonky2/src/gates/subtraction_u32.rs index fa817ce4..f083db5a 100644 --- a/plonky2/src/gates/subtraction_u32.rs +++ b/plonky2/src/gates/subtraction_u32.rs @@ -416,10 +416,7 @@ mod tests { v1.append(&mut output_limbs); } - v0.iter() - .chain(v1.iter()) - .map(|&x| x.into()) - .collect::>() + v0.iter().chain(v1.iter()).map(|&x| x.into()).collect() } let mut rng = rand::thread_rng(); diff --git a/plonky2/src/gates/switch.rs b/plonky2/src/gates/switch.rs index c5271a60..583f9a6e 100644 --- a/plonky2/src/gates/switch.rs +++ b/plonky2/src/gates/switch.rs @@ -432,7 +432,7 @@ mod tests { v.push(F::from_bool(switch)); } - v.iter().map(|&x| x.into()).collect::>() + v.iter().map(|&x| x.into()).collect() } let first_inputs: Vec> = (0..num_copies).map(|_| F::rand_vec(CHUNK_SIZE)).collect(); diff --git a/plonky2/src/iop/generator.rs b/plonky2/src/iop/generator.rs index 368232fd..73978f5c 100644 --- a/plonky2/src/iop/generator.rs +++ b/plonky2/src/iop/generator.rs @@ -10,7 +10,7 @@ use crate::gadgets::biguint::BigUintTarget; use crate::gadgets::nonnative::NonNativeTarget; use crate::hash::hash_types::{HashOut, HashOutTarget, RichField}; use crate::iop::ext_target::ExtensionTarget; -use crate::iop::target::Target; +use crate::iop::target::{BoolTarget, Target}; use crate::iop::wire::Wire; use crate::iop::witness::{PartialWitness, PartitionWitness, Witness}; use crate::plonk::circuit_data::{CommonCircuitData, ProverOnlyCircuitData}; @@ -161,12 +161,17 @@ impl GeneratedValues { self.target_values.push((target, value)) } - fn set_u32_target(&mut self, target: U32Target, value: u32) { + pub fn set_bool_target(&mut self, target: BoolTarget, value: bool) { + self.set_target(target.target, F::from_bool(value)) + } + + pub fn set_u32_target(&mut self, target: U32Target, value: u32) { self.set_target(target.0, F::from_canonical_u32(value)) } pub fn set_biguint_target(&mut self, target: BigUintTarget, value: BigUint) { let mut limbs = value.to_u32_digits(); + assert!(target.num_limbs() >= limbs.len()); limbs.resize(target.num_limbs(), 0); diff --git a/plonky2/src/plonk/circuit_builder.rs b/plonky2/src/plonk/circuit_builder.rs index d9bcc1cf..ad216d69 100644 --- a/plonky2/src/plonk/circuit_builder.rs +++ b/plonky2/src/plonk/circuit_builder.rs @@ -16,6 +16,7 @@ use crate::gadgets::arithmetic::BaseArithmeticOperation; use crate::gadgets::arithmetic_extension::ExtensionArithmeticOperation; use crate::gadgets::arithmetic_u32::U32Target; use crate::gadgets::polynomial::PolynomialCoeffsExtTarget; +use crate::gates::add_many_u32::U32AddManyGate; use crate::gates::arithmetic_base::ArithmeticGate; use crate::gates::arithmetic_extension::ArithmeticExtensionGate; use crate::gates::arithmetic_u32::U32ArithmeticGate; @@ -203,6 +204,12 @@ impl, const D: usize> CircuitBuilder { BoolTarget::new_unsafe(self.add_virtual_target()) } + pub fn add_virtual_bool_target_safe(&mut self) -> BoolTarget { + let b = BoolTarget::new_unsafe(self.add_virtual_target()); + self.assert_bool(b); + b + } + /// Adds a gate to the circuit, and returns its index. pub fn add_gate>(&mut self, gate_type: G, constants: Vec) -> usize { self.check_gate_compatibility(&gate_type); @@ -233,7 +240,7 @@ impl, const D: usize> CircuitBuilder { fn check_gate_compatibility>(&self, gate: &G) { assert!( gate.num_wires() <= self.config.num_wires, - "{:?} requires {} wires, but our GateConfig has only {}", + "{:?} requires {} wires, but our CircuitConfig has only {}", gate.id(), gate.num_wires(), self.config.num_wires @@ -654,14 +661,14 @@ impl, const D: usize> CircuitBuilder { let subgroup = F::two_adic_subgroup(degree_bits); let constant_vecs = timed!( - &mut timing, + timing, "generate constant polynomials", self.constant_polys(&prefixed_gates, num_constants) ); let k_is = get_unique_coset_shifts(degree, self.config.num_routed_wires); let (sigma_vecs, forest) = timed!( - &mut timing, + timing, "generate sigma polynomials", self.sigma_vecs(&k_is, &subgroup) ); @@ -816,9 +823,12 @@ pub struct BatchedGates, const D: usize> { /// of switches pub(crate) current_switch_gates: Vec, usize, usize)>>, + /// A map `n -> (g, i)` from `n` number of addends to an available `U32AddManyGate` of that size with gate + /// index `g` and already using `i` random accesses. + pub(crate) free_u32_add_many: HashMap, + /// The `U32ArithmeticGate` currently being filled (so new u32 arithmetic operations will be added to this gate before creating a new one) pub(crate) current_u32_arithmetic_gate: Option<(usize, usize)>, - /// The `U32SubtractionGate` currently being filled (so new u32 subtraction operations will be added to this gate before creating a new one) pub(crate) current_u32_subtraction_gate: Option<(usize, usize)>, @@ -834,6 +844,7 @@ impl, const D: usize> BatchedGates { free_mul: HashMap::new(), free_random_access: HashMap::new(), current_switch_gates: Vec::new(), + free_u32_add_many: HashMap::new(), current_u32_arithmetic_gate: None, current_u32_subtraction_gate: None, free_constant: None, @@ -931,8 +942,8 @@ impl, const D: usize> CircuitBuilder { (gate, i) } - /// Finds the last available random access gate with the given `vec_size` or add one if there aren't any. - /// Returns `(g,i)` such that there is a random access gate with the given `vec_size` at index + /// Finds the last available random access gate with the given `bits` or adds one if there aren't any. + /// Returns `(g,i)` such that there is a random access gate for the given `bits` at index /// `g` and the gate's `i`-th random access is available. pub(crate) fn find_random_access_gate(&mut self, bits: usize) -> (usize, usize) { let (gate, i) = self @@ -994,6 +1005,35 @@ impl, const D: usize> CircuitBuilder { (gate, gate_index, next_copy) } + /// Finds the last available U32 add-many gate with the given `num_addends` or adds one if there aren't any. + /// Returns `(g,i)` such that there is a `U32AddManyGate` for the given `num_addends` at index + /// `g` and the gate's `i`-th copy is available. + pub(crate) fn find_u32_add_many_gate(&mut self, num_addends: usize) -> (usize, usize) { + let (gate, i) = self + .batched_gates + .free_u32_add_many + .get(&num_addends) + .copied() + .unwrap_or_else(|| { + let gate = self.add_gate( + U32AddManyGate::new_from_config(&self.config, num_addends), + vec![], + ); + (gate, 0) + }); + + // Update `free_u32_add_many` with new values. + if i + 1 < U32AddManyGate::::new_from_config(&self.config, num_addends).num_ops { + self.batched_gates + .free_u32_add_many + .insert(num_addends, (gate, i + 1)); + } else { + self.batched_gates.free_u32_add_many.remove(&num_addends); + } + + (gate, i) + } + pub(crate) fn find_u32_arithmetic_gate(&mut self) -> (usize, usize) { let (gate_index, copy) = match self.batched_gates.current_u32_arithmetic_gate { None => { @@ -1140,6 +1180,28 @@ impl, const D: usize> CircuitBuilder { } } + /// Fill the remaining unused u32 add-many operations with zeros, so that all + /// `U32AddManyGenerator`s are run. + fn fill_u32_add_many_gates(&mut self) { + let zero = self.zero_u32(); + for (num_addends, (_, i)) in self.batched_gates.free_u32_add_many.clone() { + let max_copies = + U32AddManyGate::::new_from_config(&self.config, num_addends).num_ops; + for _ in i..max_copies { + let gate = U32AddManyGate::::new_from_config(&self.config, num_addends); + let (gate_index, copy) = self.find_u32_add_many_gate(num_addends); + + for j in 0..num_addends { + self.connect( + Target::wire(gate_index, gate.wire_ith_op_jth_addend(copy, j)), + zero.0, + ); + } + self.connect(Target::wire(gate_index, gate.wire_ith_carry(copy)), zero.0); + } + } + } + /// Fill the remaining unused U32 arithmetic operations with zeros, so that all /// `U32ArithmeticGenerator`s are run. fn fill_u32_arithmetic_gates(&mut self) { @@ -1172,6 +1234,7 @@ impl, const D: usize> CircuitBuilder { self.fill_mul_gates(); self.fill_random_access_gates(); self.fill_switch_gates(); + self.fill_u32_add_many_gates(); self.fill_u32_arithmetic_gates(); self.fill_u32_subtraction_gates(); } diff --git a/plonky2/src/plonk/circuit_data.rs b/plonky2/src/plonk/circuit_data.rs index dd7ebc25..7e667b8d 100644 --- a/plonky2/src/plonk/circuit_data.rs +++ b/plonky2/src/plonk/circuit_data.rs @@ -49,7 +49,7 @@ pub struct CircuitConfig { impl Default for CircuitConfig { fn default() -> Self { - CircuitConfig::standard_recursion_config() + Self::standard_recursion_config() } } @@ -79,6 +79,13 @@ impl CircuitConfig { } } + pub fn standard_ecc_config() -> Self { + Self { + num_wires: 136, + ..Self::standard_recursion_config() + } + } + pub fn standard_recursion_zk_config() -> Self { CircuitConfig { zero_knowledge: true,