From f7ce33b7aef73a51ab7db64fffb7b0bc2ed0710b Mon Sep 17 00:00:00 2001 From: Nicholas Ward Date: Tue, 26 Oct 2021 15:56:08 -0700 Subject: [PATCH] using refs in right places; and lots of fixes --- src/gadgets/biguint.rs | 82 ++++++++++++------------ src/gadgets/nonnative.rs | 132 ++++++++++++--------------------------- 2 files changed, 80 insertions(+), 134 deletions(-) diff --git a/src/gadgets/biguint.rs b/src/gadgets/biguint.rs index 9b3895a7..1ccf9c3a 100644 --- a/src/gadgets/biguint.rs +++ b/src/gadgets/biguint.rs @@ -26,7 +26,7 @@ impl BigUintTarget { } impl, const D: usize> CircuitBuilder { - fn constant_biguint(&mut self, value: &BigUint) -> BigUintTarget { + pub fn constant_biguint(&mut self, value: &BigUint) -> BigUintTarget { let limb_values = value.to_u32_digits(); let mut limbs = Vec::new(); for i in 0..limb_values.len() { @@ -38,7 +38,7 @@ impl, const D: usize> CircuitBuilder { BigUintTarget { limbs } } - fn connect_biguint(&mut self, lhs: BigUintTarget, rhs: BigUintTarget) { + pub fn connect_biguint(&mut self, lhs: &BigUintTarget, rhs: &BigUintTarget) { let min_limbs = lhs.num_limbs().min(rhs.num_limbs()); for i in 0..min_limbs { self.connect_u32(lhs.get_limb(i), rhs.get_limb(i)); @@ -52,7 +52,7 @@ impl, const D: usize> CircuitBuilder { } } - fn pad_biguints( + pub fn pad_biguints<'a>( &mut self, a: BigUintTarget, b: BigUintTarget, @@ -74,7 +74,7 @@ impl, const D: usize> CircuitBuilder { } } - fn cmp_biguint(&mut self, a: BigUintTarget, b: BigUintTarget) -> BoolTarget { + pub fn cmp_biguint(&mut self, a: &BigUintTarget, b: &BigUintTarget) -> BoolTarget { let (padded_a, padded_b) = self.pad_biguints(a.clone(), b.clone()); let a_vec = padded_a.limbs.iter().map(|&x| x.0).collect(); @@ -83,7 +83,7 @@ impl, const D: usize> CircuitBuilder { self.list_le(a_vec, b_vec, 32) } - fn add_virtual_biguint_target(&mut self, num_limbs: usize) -> BigUintTarget { + pub fn add_virtual_biguint_target(&mut self, num_limbs: usize) -> BigUintTarget { let limbs = (0..num_limbs) .map(|_| self.add_virtual_u32_target()) .collect(); @@ -92,7 +92,7 @@ impl, const D: usize> CircuitBuilder { } // Add two `BigUintTarget`s. - pub fn add_biguint(&mut self, a: BigUintTarget, b: BigUintTarget) -> BigUintTarget { + pub fn add_biguint(&mut self, a: &BigUintTarget, b: &BigUintTarget) -> BigUintTarget { let num_limbs = a.num_limbs().max(b.num_limbs()); let mut combined_limbs = vec![]; @@ -121,7 +121,7 @@ impl, const D: usize> CircuitBuilder { } // Subtract two `BigUintTarget`s. We assume that the first is larger than the second. - pub fn sub_biguint(&mut self, a: BigUintTarget, b: BigUintTarget) -> BigUintTarget { + pub fn sub_biguint(&mut self, a: &BigUintTarget, b: &BigUintTarget) -> BigUintTarget { let num_limbs = a.limbs.len(); debug_assert!(b.limbs.len() == num_limbs); @@ -140,7 +140,7 @@ impl, const D: usize> CircuitBuilder { } } - pub fn mul_biguint(&mut self, a: BigUintTarget, b: BigUintTarget) -> BigUintTarget { + pub fn mul_biguint(&mut self, a: &BigUintTarget, b: &BigUintTarget) -> BigUintTarget { let num_limbs = a.limbs.len(); debug_assert!(b.limbs.len() == num_limbs); @@ -170,8 +170,8 @@ impl, const D: usize> CircuitBuilder { pub fn div_rem_biguint( &mut self, - a: BigUintTarget, - b: BigUintTarget, + a: &BigUintTarget, + b: &BigUintTarget, ) -> (BigUintTarget, BigUintTarget) { let num_limbs = a.limbs.len(); let div = self.add_virtual_biguint_target(num_limbs); @@ -185,22 +185,22 @@ impl, const D: usize> CircuitBuilder { _phantom: PhantomData, }); - let div_b = self.mul_biguint(div.clone(), b.clone()); - let div_b_plus_rem = self.add_biguint(div_b, rem.clone()); - self.connect_biguint(a, div_b_plus_rem); + let div_b = self.mul_biguint(&div, &b); + let div_b_plus_rem = self.add_biguint(&div_b, &rem); + self.connect_biguint(&a, &div_b_plus_rem); - let cmp_rem_b = self.cmp_biguint(rem.clone(), b); + let cmp_rem_b = self.cmp_biguint(&rem, b); self.assert_one(cmp_rem_b.target); (div, rem) } - pub fn div_biguint(&mut self, a: BigUintTarget, b: BigUintTarget) -> BigUintTarget { + pub fn div_biguint(&mut self, a: &BigUintTarget, b: &BigUintTarget) -> BigUintTarget { let (div, _rem) = self.div_rem_biguint(a, b); div } - pub fn rem_biguint(&mut self, a: BigUintTarget, b: BigUintTarget) -> BigUintTarget { + pub fn rem_biguint(&mut self, a: &BigUintTarget, b: &BigUintTarget) -> BigUintTarget { let (_div, rem) = self.div_rem_biguint(a, b); rem } @@ -259,12 +259,12 @@ mod tests { let pw = PartialWitness::new(); let mut builder = CircuitBuilder::::new(config); - let x = builder.constant_biguint(x_value); - let y = builder.constant_biguint(y_value); - let z = builder.add_biguint(x, y); - let expected_z = builder.constant_biguint(expected_z_value); + let x = builder.constant_biguint(&x_value); + let y = builder.constant_biguint(&y_value); + let z = builder.add_biguint(&x, &y); + let expected_z = builder.constant_biguint(&expected_z_value); - builder.connect_biguint(z, expected_z); + builder.connect_biguint(&z, &expected_z); let data = builder.build(); let proof = data.prove(pw).unwrap(); @@ -282,12 +282,12 @@ mod tests { let pw = PartialWitness::new(); let mut builder = CircuitBuilder::::new(config); - let x = builder.constant_biguint(x_value); - let y = builder.constant_biguint(y_value); - let z = builder.sub_biguint(x, y); - let expected_z = builder.constant_biguint(expected_z_value); + let x = builder.constant_biguint(&x_value); + let y = builder.constant_biguint(&y_value); + let z = builder.sub_biguint(&x, &y); + let expected_z = builder.constant_biguint(&expected_z_value); - builder.connect_biguint(z, expected_z); + builder.connect_biguint(&z, &expected_z); let data = builder.build(); let proof = data.prove(pw).unwrap(); @@ -305,12 +305,12 @@ mod tests { let pw = PartialWitness::new(); let mut builder = CircuitBuilder::::new(config); - let x = builder.constant_biguint(x_value); - let y = builder.constant_biguint(y_value); - let z = builder.mul_biguint(x, y); - let expected_z = builder.constant_biguint(expected_z_value); + let x = builder.constant_biguint(&x_value); + let y = builder.constant_biguint(&y_value); + let z = builder.mul_biguint(&x, &y); + let expected_z = builder.constant_biguint(&expected_z_value); - builder.connect_biguint(z, expected_z); + builder.connect_biguint(&z, &expected_z); let data = builder.build(); let proof = data.prove(pw).unwrap(); @@ -327,9 +327,9 @@ mod tests { let pw = PartialWitness::new(); let mut builder = CircuitBuilder::::new(config); - let x = builder.constant_biguint(x_value); - let y = builder.constant_biguint(y_value); - let cmp = builder.cmp_biguint(x, y); + let x = builder.constant_biguint(&x_value); + let y = builder.constant_biguint(&y_value); + let cmp = builder.cmp_biguint(&x, &y); let expected_cmp = builder.constant_bool(false); builder.connect(cmp.target, expected_cmp.target); @@ -350,15 +350,15 @@ mod tests { let pw = PartialWitness::new(); let mut builder = CircuitBuilder::::new(config); - let x = builder.constant_biguint(x_value); - let y = builder.constant_biguint(y_value); - let (div, rem) = builder.div_rem_biguint(x, y); + let x = builder.constant_biguint(&x_value); + let y = builder.constant_biguint(&y_value); + let (div, rem) = builder.div_rem_biguint(&x, &y); - let expected_div = builder.constant_biguint(expected_div_value); - let expected_rem = builder.constant_biguint(expected_rem_value); + let expected_div = builder.constant_biguint(&expected_div_value); + let expected_rem = builder.constant_biguint(&expected_rem_value); - builder.connect_biguint(div, expected_div); - builder.connect_biguint(rem, expected_rem); + builder.connect_biguint(&div, &expected_div); + builder.connect_biguint(&rem, &expected_rem); let data = builder.build(); let proof = data.prove(pw).unwrap(); diff --git a/src/gadgets/nonnative.rs b/src/gadgets/nonnative.rs index ff344697..61d0ac5c 100644 --- a/src/gadgets/nonnative.rs +++ b/src/gadgets/nonnative.rs @@ -1,5 +1,6 @@ use std::marker::PhantomData; +use crate::gadgets::biguint::BigUintTarget; use crate::field::field_types::RichField; use crate::field::{extension_field::Extendable, field_types::Field}; use crate::gadgets::arithmetic_u32::U32Target; @@ -20,122 +21,67 @@ impl, const D: usize> CircuitBuilder { .collect() } - // Add two `ForeignFieldTarget`s. - pub fn add_nonnative( - &mut self, - a: ForeignFieldTarget, - b: ForeignFieldTarget, - ) -> ForeignFieldTarget { - let num_limbs = a.limbs.len(); - debug_assert!(b.limbs.len() == num_limbs); - - let mut combined_limbs = self.add_virtual_u32_targets(num_limbs + 1); - let mut carry = self.zero_u32(); - for i in 0..num_limbs { - let (new_limb, new_carry) = - self.add_three_u32(carry.clone(), a.limbs[i].clone(), b.limbs[i].clone()); - carry = new_carry; - combined_limbs[i] = new_limb; - } - combined_limbs[num_limbs] = carry; - - let reduced_limbs = self.reduce_add_result::(combined_limbs); - ForeignFieldTarget { - limbs: reduced_limbs, - _phantom: PhantomData, + pub fn ff_to_biguint(&mut self, x: &ForeignFieldTarget) -> BigUintTarget { + BigUintTarget { + limbs: x.limbs.clone(), } } - /// Reduces the result of a non-native addition. - pub fn reduce_add_result(&mut self, limbs: Vec) -> Vec { - let num_limbs = limbs.len(); + // Add two `ForeignFieldTarget`s. + pub fn add_nonnative( + &mut self, + a: &ForeignFieldTarget, + b: &ForeignFieldTarget, + ) -> ForeignFieldTarget { + let a_biguint = self.ff_to_biguint(a); + let b_biguint = self.ff_to_biguint(b); + let result = self.add_biguint(&a_biguint, &b_biguint); - let mut modulus_limbs = self.order_u32_limbs::(); - modulus_limbs.push(self.zero_u32()); - - let needs_reduce = self.list_le_u32(modulus_limbs.clone(), limbs.clone()); - - let mut to_subtract = vec![]; - for i in 0..num_limbs { - let (low, _high) = self.mul_u32(modulus_limbs[i], U32Target(needs_reduce.target)); - to_subtract.push(low); - } - - let mut reduced_limbs = vec![]; - - let mut borrow = self.zero_u32(); - for i in 0..num_limbs { - let (result, new_borrow) = self.sub_u32(limbs[i], to_subtract[i], borrow); - reduced_limbs[i] = result; - borrow = new_borrow; - } - // Borrow should be zero here. - - reduced_limbs + self.reduce(&result) } // Subtract two `ForeignFieldTarget`s. We assume that the first is larger than the second. pub fn sub_nonnative( &mut self, - a: ForeignFieldTarget, - b: ForeignFieldTarget, + a: &ForeignFieldTarget, + b: &ForeignFieldTarget, ) -> ForeignFieldTarget { - let num_limbs = a.limbs.len(); - debug_assert!(b.limbs.len() == num_limbs); + let a_biguint = self.ff_to_biguint(a); + let b_biguint = self.ff_to_biguint(b); + let result = self.sub_biguint(&a_biguint, &b_biguint); - let mut result_limbs = vec![]; - - let mut borrow = self.zero_u32(); - for i in 0..num_limbs { - let (result, new_borrow) = self.sub_u32(a.limbs[i], b.limbs[i], borrow); - result_limbs[i] = result; - borrow = new_borrow; - } - // Borrow should be zero here. - - ForeignFieldTarget { - limbs: result_limbs, - _phantom: PhantomData, - } + self.reduce(&result) } pub fn mul_nonnative( &mut self, - a: ForeignFieldTarget, - b: ForeignFieldTarget, + a: &ForeignFieldTarget, + b: &ForeignFieldTarget, ) -> ForeignFieldTarget { - let num_limbs = a.limbs.len(); - debug_assert!(b.limbs.len() == num_limbs); + let a_biguint = self.ff_to_biguint(a); + let b_biguint = self.ff_to_biguint(b); + let result = self.mul_biguint(&a_biguint, &b_biguint); - let mut combined_limbs = self.add_virtual_u32_targets(2 * num_limbs - 1); - let mut to_add = vec![vec![]; 2 * num_limbs]; - for i in 0..num_limbs { - for j in 0..num_limbs { - let (product, carry) = self.mul_u32(a.limbs[i], b.limbs[j]); - to_add[i + j].push(product); - to_add[i + j + 1].push(carry); - } - } + self.reduce(&result) + } - let mut combined_limbs = vec![]; - let mut carry = self.zero_u32(); - for i in 0..2 * num_limbs { - to_add[i].push(carry); - let (new_result, new_carry) = self.add_many_u32(to_add[i].clone()); - combined_limbs.push(new_result); - carry = new_carry; - } - combined_limbs.push(carry); - - let reduced_limbs = self.reduce_mul_result::(combined_limbs); + /// Returns `x % |FF|` as a `ForeignFieldTarget`. + fn reduce( + &mut self, + x: &BigUintTarget, + ) -> ForeignFieldTarget { + let modulus = FF::order(); + let order_target = self.constant_biguint(&modulus); + let value = self.rem_biguint(x, &order_target); ForeignFieldTarget { - limbs: reduced_limbs, + limbs: value.limbs, _phantom: PhantomData, } } - pub fn reduce_mul_result(&mut self, limbs: Vec) -> Vec { - todo!() + fn reduce_ff(&mut self, x: &ForeignFieldTarget) -> ForeignFieldTarget { + let x_biguint = self.ff_to_biguint(x); + self.reduce(&x_biguint) } }