From 08fa4031badd0ec42355404ed70ee6b4886c8f5f Mon Sep 17 00:00:00 2001 From: Nicholas Ward Date: Thu, 20 Jan 2022 16:06:00 -0800 Subject: [PATCH] ECDSA merge --- plonky2/src/gadgets/arithmetic_u32.rs | 82 +++++++ {waksman/src => plonky2/src/gadgets}/ecdsa.rs | 0 plonky2/src/gadgets/mod.rs | 1 + plonky2/src/gadgets/multiple_comparison.rs | 9 +- plonky2/src/gadgets/nonnative.rs | 206 ++++++++++++++++-- plonky2/src/gates/mod.rs | 2 + plonky2/src/iop/generator.rs | 7 +- plonky2/src/plonk/circuit_builder.rs | 82 ++++++- 8 files changed, 370 insertions(+), 19 deletions(-) rename {waksman/src => plonky2/src/gadgets}/ecdsa.rs (100%) diff --git a/plonky2/src/gadgets/arithmetic_u32.rs b/plonky2/src/gadgets/arithmetic_u32.rs index 6116f61b..0fbd076f 100644 --- a/plonky2/src/gadgets/arithmetic_u32.rs +++ b/plonky2/src/gadgets/arithmetic_u32.rs @@ -4,6 +4,7 @@ use crate::gates::arithmetic_u32::U32ArithmeticGate; use crate::gates::subtraction_u32::U32SubtractionGate; use crate::hash::hash_types::RichField; use crate::iop::target::Target; +use crate::iop::witness::{PartitionWitness, Witness}; use crate::plonk::circuit_builder::CircuitBuilder; #[derive(Clone, Copy, Debug)] @@ -152,4 +153,85 @@ impl, const D: usize> CircuitBuilder { (output_result, output_borrow) } + + pub fn split_to_u32(&mut self, x: Target) -> (U32Target, U32Target) { + let low = self.add_virtual_u32_target(); + let high = self.add_virtual_u32_target(); + + let base = self.constant(F::from_canonical_u64(1u64 << 32)); + let combined = self.mul_add(high.0, base, low.0); + self.connect(x, combined); + + self.add_simple_generator(SplitToU32Generator:: { + x: x.clone(), + low: low.clone(), + high: high.clone(), + _phantom: PhantomData, + }); + + (low, high) + } } + +#[derive(Debug)] +struct SplitToU32Generator, const D: usize> { + x: Target, + low: U32Target, + high: U32Target, + _phantom: PhantomData, +} + +impl, const D: usize> SimpleGenerator + for SplitToU32Generator +{ + fn dependencies(&self) -> Vec { + vec![self.x] + } + + fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { + let x = witness.get_target(self.x.clone()); + let x_u64 = x.to_canonical_u64(); + let low = x_u64 as u32; + let high: u32 = (x_u64 >> 32).try_into().unwrap(); + println!("LOW: {}", low); + println!("HIGH: {}", high); + + out_buffer.set_u32_target(self.low.clone(), low); + out_buffer.set_u32_target(self.high.clone(), high); + } +} + +#[cfg(test)] +mod tests { + use anyhow::Result; + + use rand::{thread_rng, Rng}; + + use crate::field::goldilocks_field::GoldilocksField; + use crate::iop::witness::PartialWitness; + use crate::plonk::circuit_builder::CircuitBuilder; + use crate::plonk::circuit_data::CircuitConfig; + use crate::plonk::verifier::verify; + + #[test] + pub fn test_add_many_u32s() -> Result<()> { + type F = GoldilocksField; + const D: usize = 4; + + let config = CircuitConfig::standard_recursion_config(); + + let pw = PartialWitness::new(); + let mut builder = CircuitBuilder::::new(config); + + let mut rng = thread_rng(); + let mut to_add = Vec::new(); + for _ in 0..10 { + to_add.push(builder.constant_u32(rng.gen())); + } + let _ = builder.add_many_u32(&to_add); + + let data = builder.build(); + let proof = data.prove(pw).unwrap(); + verify(proof, &data.verifier_only, &data.common) + } +} \ No newline at end of file diff --git a/waksman/src/ecdsa.rs b/plonky2/src/gadgets/ecdsa.rs similarity index 100% rename from waksman/src/ecdsa.rs rename to plonky2/src/gadgets/ecdsa.rs diff --git a/plonky2/src/gadgets/mod.rs b/plonky2/src/gadgets/mod.rs index ec4d1263..5dacdb51 100644 --- a/plonky2/src/gadgets/mod.rs +++ b/plonky2/src/gadgets/mod.rs @@ -2,6 +2,7 @@ pub mod arithmetic; pub mod arithmetic_extension; pub mod arithmetic_u32; pub mod biguint; +pub mod binary_arithmetic; pub mod curve; pub mod ecdsa; pub mod hash; diff --git a/plonky2/src/gadgets/multiple_comparison.rs b/plonky2/src/gadgets/multiple_comparison.rs index 88b94f3f..70afcab5 100644 --- a/plonky2/src/gadgets/multiple_comparison.rs +++ b/plonky2/src/gadgets/multiple_comparison.rs @@ -60,8 +60,13 @@ impl, const D: usize> CircuitBuilder { /// Helper function for comparing, specifically, lists of `U32Target`s. pub fn list_le_u32(&mut self, a: Vec, b: Vec) -> BoolTarget { - let a_targets = a.iter().map(|&t| t.0).collect(); - let b_targets = b.iter().map(|&t| t.0).collect(); + // let a_targets = a.iter().map(|&t| t.0).collect(); + // let b_targets = b.iter().map(|&t| t.0).collect(); + // self.list_le(a_targets, b_targets, 32) + + let num = a.len() / 2; + let a_targets = self.add_virtual_targets(num); + let b_targets = self.add_virtual_targets(num); self.list_le(a_targets, b_targets, 32) } } diff --git a/plonky2/src/gadgets/nonnative.rs b/plonky2/src/gadgets/nonnative.rs index 824d851c..cca22fd1 100644 --- a/plonky2/src/gadgets/nonnative.rs +++ b/plonky2/src/gadgets/nonnative.rs @@ -39,6 +39,10 @@ impl, const D: usize> CircuitBuilder { self.biguint_to_nonnative(&x_biguint) } + pub fn zero_nonnative(&mut self) -> NonNativeTarget { + self.constant_nonnative(FF::ZERO) + } + // Assert that two NonNativeTarget's, both assumed to be in reduced form, are equal. pub fn connect_nonnative( &mut self, @@ -70,6 +74,22 @@ impl, const D: usize> CircuitBuilder { self.reduce(&result) } + pub fn add_many_nonnative( + &mut self, + to_add: &[NonNativeTarget], + ) -> NonNativeTarget { + if to_add.len() == 1 { + return to_add[0].clone(); + } + + let mut result = self.add_biguint(&to_add[0].value, &to_add[1].value); + for i in 2..to_add.len() { + result = self.add_biguint(&result, &to_add[i].value); + } + + self.reduce(&result) + } + // Subtract two `NonNativeTarget`s. pub fn sub_nonnative( &mut self, @@ -94,6 +114,22 @@ impl, const D: usize> CircuitBuilder { self.reduce(&result) } + pub fn mul_many_nonnative( + &mut self, + to_mul: &[NonNativeTarget], + ) -> NonNativeTarget { + if to_mul.len() == 1 { + return to_mul[0].clone(); + } + + let mut result = self.mul_biguint(&to_mul[0].value, &to_mul[1].value); + for i in 2..to_mul.len() { + result = self.mul_biguint(&result, &to_mul[i].value); + } + + self.reduce(&result) + } + pub fn neg_nonnative(&mut self, x: &NonNativeTarget) -> NonNativeTarget { let zero_target = self.constant_biguint(&BigUint::zero()); let zero_ff = self.biguint_to_nonnative(&zero_target); @@ -104,21 +140,27 @@ impl, const D: usize> CircuitBuilder { pub fn inv_nonnative(&mut self, x: &NonNativeTarget) -> NonNativeTarget { let num_limbs = x.value.num_limbs(); let inv_biguint = self.add_virtual_biguint_target(num_limbs); + let div = self.add_virtual_biguint_target(num_limbs); + + self.add_simple_generator(NonNativeInverseGenerator:: { + x: x.clone(), + inv: inv_biguint.clone(), + div: div.clone(), + _phantom: PhantomData, + }); + + let product = self.mul_biguint(&x.value, &inv_biguint); + + let modulus = self.constant_biguint(&FF::order()); + let mod_times_div = self.mul_biguint(&modulus, &div); + let one = self.constant_biguint(&BigUint::one()); + let expected_product = self.add_biguint(&mod_times_div, &one); + self.connect_biguint(&product, &expected_product); + let inv = NonNativeTarget:: { value: inv_biguint, _phantom: PhantomData, }; - - self.add_simple_generator(NonNativeInverseGenerator:: { - x: x.clone(), - inv: inv.clone(), - _phantom: PhantomData, - }); - - let product = self.mul_nonnative(x, &inv); - let one = self.constant_nonnative(FF::ONE); - self.connect_nonnative(&product, &one); - inv } @@ -138,10 +180,70 @@ impl, const D: usize> CircuitBuilder { /// Returns `x % |FF|` as a `NonNativeTarget`. fn reduce(&mut self, x: &BigUintTarget) -> NonNativeTarget { + println!("NUM LIMBS: {}", x.limbs.len()); + let before = self.num_gates(); + let modulus = FF::order(); let order_target = self.constant_biguint(&modulus); let value = self.rem_biguint(x, &order_target); + println!("NUMBER OF GATES: {}", self.num_gates() - before); + println!("OUTPUT LIMBS: {}", value.limbs.len()); + + NonNativeTarget { + value, + _phantom: PhantomData, + } + } + + /// Returns `x % |FF|` as a `NonNativeTarget`. + fn reduce_by_bits(&mut self, x: &BigUintTarget) -> NonNativeTarget { + println!("NUM LIMBS: {}", x.limbs.len()); + let before = self.num_gates(); + + let mut powers_of_two = Vec::new(); + let mut cur_power_of_two = FF::ONE; + let two = FF::TWO; + let mut max_num_limbs = 0; + for _ in 0..(x.limbs.len() * 32) { + let cur_power = self.constant_biguint(&cur_power_of_two.to_biguint()); + max_num_limbs = max_num_limbs.max(cur_power.limbs.len()); + powers_of_two.push(cur_power.limbs); + + cur_power_of_two *= two; + } + + let mut result_limbs_unreduced = vec![self.zero(); max_num_limbs]; + for i in 0..x.limbs.len() { + let this_limb = x.limbs[i]; + let bits = self.split_le(this_limb.0, 32); + for b in 0..bits.len() { + let this_power = powers_of_two[32 * i + b].clone(); + for x in 0..this_power.len() { + result_limbs_unreduced[x] = self.mul_add(bits[b].target, this_power[x].0, result_limbs_unreduced[x]); + } + } + } + + let mut result_limbs_reduced = Vec::new(); + let mut carry = self.zero_u32(); + for i in 0..result_limbs_unreduced.len() { + println!("{}", i); + let (low, high) = self.split_to_u32(result_limbs_unreduced[i]); + let (cur, overflow) = self.add_u32(carry, low); + let (new_carry, _) = self.add_many_u32(&[overflow, high, carry]); + result_limbs_reduced.push(cur); + carry = new_carry; + } + result_limbs_reduced.push(carry); + + let value = BigUintTarget { + limbs: result_limbs_reduced, + }; + + println!("NUMBER OF GATES: {}", self.num_gates() - before); + println!("OUTPUT LIMBS: {}", value.limbs.len()); + NonNativeTarget { value, _phantom: PhantomData, @@ -190,7 +292,8 @@ impl, const D: usize> CircuitBuilder { #[derive(Debug)] struct NonNativeInverseGenerator, const D: usize, FF: Field> { x: NonNativeTarget, - inv: NonNativeTarget, + inv: BigUintTarget, + div: BigUintTarget, _phantom: PhantomData, } @@ -205,7 +308,14 @@ impl, const D: usize, FF: Field> SimpleGenerator let x = witness.get_nonnative_target(self.x.clone()); let inv = x.inverse(); - out_buffer.set_nonnative_target(self.inv.clone(), inv); + let x_biguint = x.to_biguint(); + let inv_biguint = inv.to_biguint(); + let prod = x_biguint * &inv_biguint; + let modulus = FF::order(); + let (div, _rem) = prod.div_rem(&modulus); + + out_buffer.set_biguint_target(self.div.clone(), div); + out_buffer.set_biguint_target(self.inv.clone(), inv_biguint); } } @@ -247,6 +357,43 @@ mod tests { verify(proof, &data.verifier_only, &data.common) } + #[test] + fn test_nonnative_many_adds() -> Result<()> { + type FF = Secp256K1Base; + let a_ff = FF::rand(); + let b_ff = FF::rand(); + let c_ff = FF::rand(); + let d_ff = FF::rand(); + let e_ff = FF::rand(); + let f_ff = FF::rand(); + let g_ff = FF::rand(); + let h_ff = FF::rand(); + let sum_ff = a_ff + b_ff + c_ff + d_ff + e_ff + f_ff + g_ff + h_ff; + + type F = GoldilocksField; + let config = CircuitConfig::standard_recursion_config(); + let pw = PartialWitness::new(); + let mut builder = CircuitBuilder::::new(config); + + let a = builder.constant_nonnative(a_ff); + let b = builder.constant_nonnative(b_ff); + let c = builder.constant_nonnative(c_ff); + let d = builder.constant_nonnative(d_ff); + let e = builder.constant_nonnative(e_ff); + let f = builder.constant_nonnative(f_ff); + let g = builder.constant_nonnative(g_ff); + let h = builder.constant_nonnative(h_ff); + let all = [a, b, c, d, e, f, g, h]; + let sum = builder.add_many_nonnative(&all); + + let sum_expected = builder.constant_nonnative(sum_ff); + builder.connect_nonnative(&sum, &sum_expected); + + let data = builder.build(); + let proof = data.prove(pw).unwrap(); + verify(proof, &data.verifier_only, &data.common) + } + #[test] fn test_nonnative_sub() -> Result<()> { type FF = Secp256K1Base; @@ -285,6 +432,7 @@ mod tests { let x_ff = FF::rand(); let y_ff = FF::rand(); let product_ff = x_ff * y_ff; + println!("PRODUCT FF: {:?}", product_ff); let config = CircuitConfig::standard_recursion_config(); let pw = PartialWitness::new(); @@ -302,6 +450,38 @@ mod tests { verify(proof, &data.verifier_only, &data.common) } + fn test_nonnative_many_muls_helper(num: usize) { + type FF = Secp256K1Base; + + type F = GoldilocksField; + let config = CircuitConfig::standard_recursion_config(); + let mut unop_builder = CircuitBuilder::::new(config.clone()); + let mut op_builder = CircuitBuilder::::new(config); + + println!("NUM: {}", num); + + let ffs: Vec<_> = (0..num).map(|_| FF::rand()).collect(); + + let op_targets: Vec<_> = ffs.iter().map(|&x| op_builder.constant_nonnative(x)).collect(); + op_builder.mul_many_nonnative(&op_targets); + println!("OPTIMIZED GATE COUNT: {}", op_builder.num_gates()); + + let unop_targets: Vec<_> = ffs.iter().map(|&x| unop_builder.constant_nonnative(x)).collect(); + let mut result = unop_targets[0].clone(); + for i in 1..unop_targets.len() { + result = unop_builder.mul_nonnative(&result, &unop_targets[i]); + } + + println!("UNOPTIMIZED GATE COUNT: {}", unop_builder.num_gates()); + } + + #[test] + fn test_nonnative_many_muls() { + for num in 2..10 { + test_nonnative_many_muls_helper(num); + } + } + #[test] fn test_nonnative_neg() -> Result<()> { type FF = Secp256K1Base; diff --git a/plonky2/src/gates/mod.rs b/plonky2/src/gates/mod.rs index a3f92615..ac4600fa 100644 --- a/plonky2/src/gates/mod.rs +++ b/plonky2/src/gates/mod.rs @@ -3,6 +3,8 @@ pub mod arithmetic_base; pub mod arithmetic_extension; +pub mod binary_arithmetic; +pub mod binary_subtraction; pub mod arithmetic_u32; pub mod assert_le; pub mod base_sum; diff --git a/plonky2/src/iop/generator.rs b/plonky2/src/iop/generator.rs index 368232fd..994ba62b 100644 --- a/plonky2/src/iop/generator.rs +++ b/plonky2/src/iop/generator.rs @@ -7,6 +7,7 @@ use plonky2_field::field_types::Field; use crate::gadgets::arithmetic_u32::U32Target; use crate::gadgets::biguint::BigUintTarget; +use crate::gadgets::binary_arithmetic::BinaryTarget; use crate::gadgets::nonnative::NonNativeTarget; use crate::hash::hash_types::{HashOut, HashOutTarget, RichField}; use crate::iop::ext_target::ExtensionTarget; @@ -161,10 +162,14 @@ impl GeneratedValues { self.target_values.push((target, value)) } - fn set_u32_target(&mut self, target: U32Target, value: u32) { + pub fn set_u32_target(&mut self, target: U32Target, value: u32) { self.set_target(target.0, F::from_canonical_u32(value)) } + pub fn set_binary_target(&mut self, target: BinaryTarget, value: F) { + self.set_target(target.0, value) + } + pub fn set_biguint_target(&mut self, target: BigUintTarget, value: BigUint) { let mut limbs = value.to_u32_digits(); assert!(target.num_limbs() >= limbs.len()); diff --git a/plonky2/src/plonk/circuit_builder.rs b/plonky2/src/plonk/circuit_builder.rs index d9bcc1cf..c9d01abe 100644 --- a/plonky2/src/plonk/circuit_builder.rs +++ b/plonky2/src/plonk/circuit_builder.rs @@ -19,6 +19,8 @@ use crate::gadgets::polynomial::PolynomialCoeffsExtTarget; use crate::gates::arithmetic_base::ArithmeticGate; use crate::gates::arithmetic_extension::ArithmeticExtensionGate; use crate::gates::arithmetic_u32::U32ArithmeticGate; +use crate::gates::binary_arithmetic::BinaryArithmeticGate; +use crate::gates::binary_subtraction::BinarySubtractionGate; use crate::gates::constant::ConstantGate; use crate::gates::gate::{Gate, GateInstance, GateRef, PrefixedGate}; use crate::gates::gate_tree::Tree; @@ -222,6 +224,11 @@ impl, const D: usize> CircuitBuilder { let gate_ref = GateRef::new(gate_type); self.gates.insert(gate_ref.clone()); + /*println!("ADDING GATE {}: {:?}", index, gate_ref); + if index == 145 { + panic!(); + }*/ + self.gate_instances.push(GateInstance { gate_ref, constants, @@ -346,6 +353,11 @@ impl, const D: usize> CircuitBuilder { U32Target(self.constant(F::from_canonical_u32(c))) } + /// Returns a BinaryTarget for the value `c`, which is assumed to be at most BITS bits. + pub fn constant_binary(&mut self, c: F) -> BinaryTarget { + BinaryTarget(self.constant(c)) + } + /// If the given target is a constant (i.e. it was created by the `constant(F)` method), returns /// its constant value. Otherwise, returns `None`. pub fn target_as_constant(&self, target: Target) -> Option { @@ -818,10 +830,14 @@ pub struct BatchedGates, const D: usize> { /// The `U32ArithmeticGate` currently being filled (so new u32 arithmetic operations will be added to this gate before creating a new one) pub(crate) current_u32_arithmetic_gate: Option<(usize, usize)>, - /// The `U32SubtractionGate` currently being filled (so new u32 subtraction operations will be added to this gate before creating a new one) pub(crate) current_u32_subtraction_gate: Option<(usize, usize)>, + /// A map `b -> (g, i)` from `b` bits to an available `BinaryArithmeticGate` for number of bits `b`. + pub(crate) free_binary_arithmetic_gate: HashMap, + /// A map `b -> (g, i)` from `b` bits to an available `BinarySubtractionGate` for number of bits `b`. + pub(crate) free_binary_subtraction_gate: HashMap, + /// An available `ConstantGate` instance, if any. pub(crate) free_constant: Option<(usize, usize)>, } @@ -836,6 +852,8 @@ impl, const D: usize> BatchedGates { current_switch_gates: Vec::new(), current_u32_arithmetic_gate: None, current_u32_subtraction_gate: None, + free_binary_arithmetic_gate: HashMap::new(), + free_binary_subtraction_gate: HashMap::new(), free_constant: None, } } @@ -931,8 +949,8 @@ impl, const D: usize> CircuitBuilder { (gate, i) } - /// Finds the last available random access gate with the given `vec_size` or add one if there aren't any. - /// Returns `(g,i)` such that there is a random access gate with the given `vec_size` at index + /// Finds the last available random access gate with the given `bits` or add one if there aren't any. + /// Returns `(g,i)` such that there is a random access gate for the given `bits` at index /// `g` and the gate's `i`-th random access is available. pub(crate) fn find_random_access_gate(&mut self, bits: usize) -> (usize, usize) { let (gate, i) = self @@ -1031,6 +1049,64 @@ impl, const D: usize> CircuitBuilder { (gate_index, copy) } + + /// Finds the last available binary arithmetic with the given `bits` or add one if there aren't any. + /// Returns `(g,i)` such that there is a binary arithmetic for the given `bits` at index + /// `g` and the gate's `i`-th copy is available. + pub(crate) fn find_binary_arithmetic_gate(&mut self) -> (usize, usize) { + let (gate, i) = self + .batched_gates + .free_binary_arithmetic_gate + .get(&BITS) + .copied() + .unwrap_or_else(|| { + let gate = self.add_gate( + BinaryArithmeticGate::::new_from_config(&self.config), + vec![], + ); + (gate, 0) + }); + + // Update `free_binary_arithmetic` with new values. + if i + 1 < BinaryArithmeticGate::::new_from_config(&self.config).num_ops { + self.batched_gates + .free_random_access + .insert(BITS, (gate, i + 1)); + } else { + self.batched_gates.free_random_access.remove(&BITS); + } + + (gate, i) + } + + /// Finds the last available binary subtraction with the given `bits` or add one if there aren't any. + /// Returns `(g,i)` such that there is a binary subtraction for the given `bits` at index + /// `g` and the gate's `i`-th copy is available. + pub(crate) fn find_binary_subtraction_gate(&mut self) -> (usize, usize) { + let (gate, i) = self + .batched_gates + .free_binary_subtraction_gate + .get(&BITS) + .copied() + .unwrap_or_else(|| { + let gate = self.add_gate( + BinarySubtractionGate::::new_from_config(&self.config), + vec![], + ); + (gate, 0) + }); + + // Update `free_binary_subtraction` with new values. + if i + 1 < BinarySubtractionGate::::new_from_config(&self.config).num_ops { + self.batched_gates + .free_random_access + .insert(BITS, (gate, i + 1)); + } else { + self.batched_gates.free_random_access.remove(&BITS); + } + + (gate, i) + } /// Returns the gate index and copy index of a free `ConstantGate` slot, potentially adding a /// new `ConstantGate` if needed.