diff --git a/plonky2/src/gadgets/biguint.rs b/plonky2/src/gadgets/biguint.rs index 24d3da7d..1a829f97 100644 --- a/plonky2/src/gadgets/biguint.rs +++ b/plonky2/src/gadgets/biguint.rs @@ -40,6 +40,10 @@ impl, const D: usize> CircuitBuilder { BigUintTarget { limbs } } + pub fn zero_biguint(&mut self) -> BigUintTarget { + self.constant_biguint(&BigUint::zero()) + } + pub fn connect_biguint(&mut self, lhs: &BigUintTarget, rhs: &BigUintTarget) { let min_limbs = lhs.num_limbs().min(rhs.num_limbs()); for i in 0..min_limbs { @@ -159,6 +163,18 @@ impl, const D: usize> CircuitBuilder { } } + pub fn mul_biguint_by_bool( + &mut self, + a: &BigUintTarget, + b: BoolTarget, + ) -> BigUintTarget { + let t = b.target; + + BigUintTarget { + limbs: a.limbs.iter().map(|l| U32Target(self.mul(l.0, t))).collect() + } + } + // Returns x * y + z. This is no more efficient than mul-then-add; it's purely for convenience (only need to call one CircuitBuilder function). pub fn mul_add_biguint( &mut self, @@ -396,11 +412,11 @@ mod tests { let y = builder.constant_biguint(&y_value); let (div, rem) = builder.div_rem_biguint(&x, &y); - // let expected_div = builder.constant_biguint(&expected_div_value); - // let expected_rem = builder.constant_biguint(&expected_rem_value); + let expected_div = builder.constant_biguint(&expected_div_value); + let expected_rem = builder.constant_biguint(&expected_rem_value); - // builder.connect_biguint(&div, &expected_div); - // builder.connect_biguint(&rem, &expected_rem); + builder.connect_biguint(&div, &expected_div); + builder.connect_biguint(&rem, &expected_rem); let data = builder.build::(); let proof = data.prove(pw).unwrap(); diff --git a/plonky2/src/gadgets/curve.rs b/plonky2/src/gadgets/curve.rs index 63e96721..907aa5e3 100644 --- a/plonky2/src/gadgets/curve.rs +++ b/plonky2/src/gadgets/curve.rs @@ -100,6 +100,7 @@ impl, const D: usize> CircuitBuilder { p1: &AffinePointTarget, p2: &AffinePointTarget, ) -> AffinePointTarget { + let before = self.num_gates(); let AffinePointTarget { x: x1, y: y1 } = p1; let AffinePointTarget { x: x2, y: y2 } = p2; @@ -123,6 +124,7 @@ impl, const D: usize> CircuitBuilder { let x3_norm = self.mul_nonnative(&x3, &z3_inv); let y3_norm = self.mul_nonnative(&y3, &z3_inv); + println!("NUM GATES: {}", self.num_gates() - before); AffinePointTarget { x: x3_norm, y: y3_norm, @@ -310,7 +312,6 @@ mod tests { } #[test] - #[ignore] fn test_curve_mul() -> Result<()> { const D: usize = 2; type C = PoseidonGoldilocksConfig; @@ -345,7 +346,6 @@ mod tests { } #[test] - #[ignore] fn test_curve_random() -> Result<()> { const D: usize = 2; type C = PoseidonGoldilocksConfig; diff --git a/plonky2/src/gadgets/nonnative.rs b/plonky2/src/gadgets/nonnative.rs index fd3dab87..1006b513 100644 --- a/plonky2/src/gadgets/nonnative.rs +++ b/plonky2/src/gadgets/nonnative.rs @@ -60,16 +60,45 @@ impl, const D: usize> CircuitBuilder { } } - // Add two `NonNativeTarget`s. pub fn add_nonnative( &mut self, a: &NonNativeTarget, b: &NonNativeTarget, ) -> NonNativeTarget { - let result = self.add_biguint(&a.value, &b.value); + let sum = self.add_virtual_nonnative_target::(); + let overflow = self.add_virtual_bool_target(); - // TODO: reduce add result with only one conditional subtraction - self.reduce(&result) + self.add_simple_generator(NonNativeAdditionGenerator:: { + a: a.clone(), + b: b.clone(), + sum: sum.clone(), + overflow: overflow.clone(), + _phantom: PhantomData, + }); + + let sum_expected = self.add_biguint(&a.value, &b.value); + + let modulus = self.constant_biguint(&FF::order()); + let mod_times_overflow = self.mul_biguint_by_bool(&modulus, overflow); + let sum_actual = self.add_biguint(&sum.value, &mod_times_overflow); + self.connect_biguint(&sum_expected, &sum_actual); + + sum + } + + pub fn mul_nonnative_by_bool( + &mut self, + a: &NonNativeTarget, + b: BoolTarget, + ) -> NonNativeTarget { + let t = b.target; + + NonNativeTarget { + value: BigUintTarget { + limbs: a.value.limbs.iter().map(|l| U32Target(self.mul(l.0, t))).collect() + }, + _phantom: PhantomData, + } } pub fn add_many_nonnative( @@ -80,12 +109,28 @@ impl, const D: usize> CircuitBuilder { return to_add[0].clone(); } - let mut result = self.add_biguint(&to_add[0].value, &to_add[1].value); - for i in 2..to_add.len() { - result = self.add_biguint(&result, &to_add[i].value); - } + let sum = self.add_virtual_nonnative_target::(); + let overflow = self.add_virtual_u32_target(); + let summands = to_add.to_vec(); - self.reduce(&result) + self.add_simple_generator(NonNativeMultipleAddsGenerator:: { + summands: summands.clone(), + sum: sum.clone(), + overflow: overflow.clone(), + _phantom: PhantomData, + }); + + let sum_expected = summands.iter().fold(self.zero_biguint(), |a, b| self.add_biguint(&a, &b.value)); + + let modulus = self.constant_biguint(&FF::order()); + let overflow_biguint = BigUintTarget { + limbs: vec![overflow], + }; + let mod_times_overflow = self.mul_biguint(&modulus, &overflow_biguint); + let sum_actual = self.add_biguint(&sum.value, &mod_times_overflow); + self.connect_biguint(&sum_expected, &sum_actual); + + sum } // Subtract two `NonNativeTarget`s. @@ -188,59 +233,6 @@ impl, const D: usize> CircuitBuilder { } } - /// Returns `x % |FF|` as a `NonNativeTarget`. - /*fn reduce_by_bits(&mut self, x: &BigUintTarget) -> NonNativeTarget { - let before = self.num_gates(); - - let mut powers_of_two = Vec::new(); - let mut cur_power_of_two = FF::ONE; - let two = FF::TWO; - let mut max_num_limbs = 0; - for _ in 0..(x.limbs.len() * 32) { - let cur_power = self.constant_biguint(&cur_power_of_two.to_biguint()); - max_num_limbs = max_num_limbs.max(cur_power.limbs.len()); - powers_of_two.push(cur_power.limbs); - - cur_power_of_two *= two; - } - - let mut result_limbs_unreduced = vec![self.zero(); max_num_limbs]; - for i in 0..x.limbs.len() { - let this_limb = x.limbs[i]; - let bits = self.split_le(this_limb.0, 32); - for b in 0..bits.len() { - let this_power = powers_of_two[32 * i + b].clone(); - for x in 0..this_power.len() { - result_limbs_unreduced[x] = self.mul_add(bits[b].target, this_power[x].0, result_limbs_unreduced[x]); - } - } - } - - let mut result_limbs_reduced = Vec::new(); - let mut carry = self.zero_u32(); - for i in 0..result_limbs_unreduced.len() { - println!("{}", i); - let (low, high) = self.split_to_u32(result_limbs_unreduced[i]); - let (cur, overflow) = self.add_u32(carry, low); - let (new_carry, _) = self.add_many_u32(&[overflow, high, carry]); - result_limbs_reduced.push(cur); - carry = new_carry; - } - result_limbs_reduced.push(carry); - - let value = BigUintTarget { - limbs: result_limbs_reduced, - }; - - println!("NUMBER OF GATES: {}", self.num_gates() - before); - println!("OUTPUT LIMBS: {}", value.limbs.len()); - - NonNativeTarget { - value, - _phantom: PhantomData, - } - }*/ - #[allow(dead_code)] fn reduce_nonnative(&mut self, x: &NonNativeTarget) -> NonNativeTarget { let x_biguint = self.nonnative_to_biguint(x); @@ -280,6 +272,74 @@ impl, const D: usize> CircuitBuilder { } } +#[derive(Debug)] +struct NonNativeAdditionGenerator, const D: usize, FF: Field> { + a: NonNativeTarget, + b: NonNativeTarget, + sum: NonNativeTarget, + overflow: BoolTarget, + _phantom: PhantomData, +} + +impl, const D: usize, FF: Field> SimpleGenerator + for NonNativeAdditionGenerator +{ + fn dependencies(&self) -> Vec { + self.a.value.limbs.iter().cloned().chain(self.b.value.limbs.clone()) + .map(|l| l.0) + .collect() + } + + fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { + let a = witness.get_nonnative_target(self.a.clone()); + let b = witness.get_nonnative_target(self.b.clone()); + let a_biguint = a.to_biguint(); + let b_biguint = b.to_biguint(); + let sum_biguint = a_biguint + b_biguint; + let modulus = FF::order(); + let (overflow, sum_reduced) = if sum_biguint > modulus { + (true, sum_biguint - modulus) + } else { + (false, sum_biguint) + }; + + out_buffer.set_biguint_target(self.sum.value.clone(), sum_reduced); + out_buffer.set_bool_target(self.overflow, overflow); + } +} + +#[derive(Debug)] +struct NonNativeMultipleAddsGenerator, const D: usize, FF: Field> { + summands: Vec>, + sum: NonNativeTarget, + overflow: U32Target, + _phantom: PhantomData, +} + +impl, const D: usize, FF: Field> SimpleGenerator + for NonNativeMultipleAddsGenerator +{ + fn dependencies(&self) -> Vec { + self.summands.iter().map(|summand| summand.value.limbs.iter().map(|limb| limb.0)) + .flatten() + .collect() + } + + fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { + let summands: Vec<_> = self.summands.iter().map(|summand| witness.get_nonnative_target(summand.clone())).collect(); + let summand_biguints: Vec<_> = summands.iter().map(|summand| summand.to_biguint()).collect(); + + let sum_biguint = summand_biguints.iter().fold(BigUint::zero(), |a, b| a + b.clone()); + + let modulus = FF::order(); + let (overflow_biguint, sum_reduced) = sum_biguint.div_rem(&modulus); + let overflow = overflow_biguint.to_u64_digits()[0] as u32; + + out_buffer.set_biguint_target(self.sum.value.clone(), sum_reduced); + out_buffer.set_u32_target(self.overflow, overflow); + } +} + #[derive(Debug)] struct NonNativeInverseGenerator, const D: usize, FF: Field> { x: NonNativeTarget, @@ -310,6 +370,8 @@ impl, const D: usize, FF: Field> SimpleGenerator } } + + #[cfg(test)] mod tests { use anyhow::Result; diff --git a/plonky2/src/gates/mod.rs b/plonky2/src/gates/mod.rs index a7591648..177db7cf 100644 --- a/plonky2/src/gates/mod.rs +++ b/plonky2/src/gates/mod.rs @@ -11,6 +11,7 @@ pub mod binary_arithmetic; pub mod binary_subtraction; pub mod comparison; pub mod constant; +// pub mod curve_double; pub mod exponentiation; pub mod gate; pub mod gate_tree; diff --git a/plonky2/src/iop/generator.rs b/plonky2/src/iop/generator.rs index d4e37dcb..5f8b8a5f 100644 --- a/plonky2/src/iop/generator.rs +++ b/plonky2/src/iop/generator.rs @@ -162,6 +162,10 @@ impl GeneratedValues { self.target_values.push((target, value)) } + pub fn set_bool_target(&mut self, target: BoolTarget, value: bool) { + self.set_target(target.target, F::from_bool(value)) + } + pub fn set_u32_target(&mut self, target: U32Target, value: u32) { self.set_target(target.0, F::from_canonical_u32(value)) }