using refs in right places; and lots of fixes

This commit is contained in:
Nicholas Ward 2021-10-26 15:56:08 -07:00
parent bfe201d951
commit f7ce33b7ae
2 changed files with 80 additions and 134 deletions

View File

@ -26,7 +26,7 @@ impl BigUintTarget {
}
impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
fn constant_biguint(&mut self, value: &BigUint) -> BigUintTarget {
pub fn constant_biguint(&mut self, value: &BigUint) -> BigUintTarget {
let limb_values = value.to_u32_digits();
let mut limbs = Vec::new();
for i in 0..limb_values.len() {
@ -38,7 +38,7 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
BigUintTarget { limbs }
}
fn connect_biguint(&mut self, lhs: BigUintTarget, rhs: BigUintTarget) {
pub fn connect_biguint(&mut self, lhs: &BigUintTarget, rhs: &BigUintTarget) {
let min_limbs = lhs.num_limbs().min(rhs.num_limbs());
for i in 0..min_limbs {
self.connect_u32(lhs.get_limb(i), rhs.get_limb(i));
@ -52,7 +52,7 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
}
}
fn pad_biguints(
pub fn pad_biguints<'a>(
&mut self,
a: BigUintTarget,
b: BigUintTarget,
@ -74,7 +74,7 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
}
}
fn cmp_biguint(&mut self, a: BigUintTarget, b: BigUintTarget) -> BoolTarget {
pub fn cmp_biguint(&mut self, a: &BigUintTarget, b: &BigUintTarget) -> BoolTarget {
let (padded_a, padded_b) = self.pad_biguints(a.clone(), b.clone());
let a_vec = padded_a.limbs.iter().map(|&x| x.0).collect();
@ -83,7 +83,7 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
self.list_le(a_vec, b_vec, 32)
}
fn add_virtual_biguint_target(&mut self, num_limbs: usize) -> BigUintTarget {
pub fn add_virtual_biguint_target(&mut self, num_limbs: usize) -> BigUintTarget {
let limbs = (0..num_limbs)
.map(|_| self.add_virtual_u32_target())
.collect();
@ -92,7 +92,7 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
}
// Add two `BigUintTarget`s.
pub fn add_biguint(&mut self, a: BigUintTarget, b: BigUintTarget) -> BigUintTarget {
pub fn add_biguint(&mut self, a: &BigUintTarget, b: &BigUintTarget) -> BigUintTarget {
let num_limbs = a.num_limbs().max(b.num_limbs());
let mut combined_limbs = vec![];
@ -121,7 +121,7 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
}
// Subtract two `BigUintTarget`s. We assume that the first is larger than the second.
pub fn sub_biguint(&mut self, a: BigUintTarget, b: BigUintTarget) -> BigUintTarget {
pub fn sub_biguint(&mut self, a: &BigUintTarget, b: &BigUintTarget) -> BigUintTarget {
let num_limbs = a.limbs.len();
debug_assert!(b.limbs.len() == num_limbs);
@ -140,7 +140,7 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
}
}
pub fn mul_biguint(&mut self, a: BigUintTarget, b: BigUintTarget) -> BigUintTarget {
pub fn mul_biguint(&mut self, a: &BigUintTarget, b: &BigUintTarget) -> BigUintTarget {
let num_limbs = a.limbs.len();
debug_assert!(b.limbs.len() == num_limbs);
@ -170,8 +170,8 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
pub fn div_rem_biguint(
&mut self,
a: BigUintTarget,
b: BigUintTarget,
a: &BigUintTarget,
b: &BigUintTarget,
) -> (BigUintTarget, BigUintTarget) {
let num_limbs = a.limbs.len();
let div = self.add_virtual_biguint_target(num_limbs);
@ -185,22 +185,22 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
_phantom: PhantomData,
});
let div_b = self.mul_biguint(div.clone(), b.clone());
let div_b_plus_rem = self.add_biguint(div_b, rem.clone());
self.connect_biguint(a, div_b_plus_rem);
let div_b = self.mul_biguint(&div, &b);
let div_b_plus_rem = self.add_biguint(&div_b, &rem);
self.connect_biguint(&a, &div_b_plus_rem);
let cmp_rem_b = self.cmp_biguint(rem.clone(), b);
let cmp_rem_b = self.cmp_biguint(&rem, b);
self.assert_one(cmp_rem_b.target);
(div, rem)
}
pub fn div_biguint(&mut self, a: BigUintTarget, b: BigUintTarget) -> BigUintTarget {
pub fn div_biguint(&mut self, a: &BigUintTarget, b: &BigUintTarget) -> BigUintTarget {
let (div, _rem) = self.div_rem_biguint(a, b);
div
}
pub fn rem_biguint(&mut self, a: BigUintTarget, b: BigUintTarget) -> BigUintTarget {
pub fn rem_biguint(&mut self, a: &BigUintTarget, b: &BigUintTarget) -> BigUintTarget {
let (_div, rem) = self.div_rem_biguint(a, b);
rem
}
@ -259,12 +259,12 @@ mod tests {
let pw = PartialWitness::new();
let mut builder = CircuitBuilder::<F, 4>::new(config);
let x = builder.constant_biguint(x_value);
let y = builder.constant_biguint(y_value);
let z = builder.add_biguint(x, y);
let expected_z = builder.constant_biguint(expected_z_value);
let x = builder.constant_biguint(&x_value);
let y = builder.constant_biguint(&y_value);
let z = builder.add_biguint(&x, &y);
let expected_z = builder.constant_biguint(&expected_z_value);
builder.connect_biguint(z, expected_z);
builder.connect_biguint(&z, &expected_z);
let data = builder.build();
let proof = data.prove(pw).unwrap();
@ -282,12 +282,12 @@ mod tests {
let pw = PartialWitness::new();
let mut builder = CircuitBuilder::<F, 4>::new(config);
let x = builder.constant_biguint(x_value);
let y = builder.constant_biguint(y_value);
let z = builder.sub_biguint(x, y);
let expected_z = builder.constant_biguint(expected_z_value);
let x = builder.constant_biguint(&x_value);
let y = builder.constant_biguint(&y_value);
let z = builder.sub_biguint(&x, &y);
let expected_z = builder.constant_biguint(&expected_z_value);
builder.connect_biguint(z, expected_z);
builder.connect_biguint(&z, &expected_z);
let data = builder.build();
let proof = data.prove(pw).unwrap();
@ -305,12 +305,12 @@ mod tests {
let pw = PartialWitness::new();
let mut builder = CircuitBuilder::<F, 4>::new(config);
let x = builder.constant_biguint(x_value);
let y = builder.constant_biguint(y_value);
let z = builder.mul_biguint(x, y);
let expected_z = builder.constant_biguint(expected_z_value);
let x = builder.constant_biguint(&x_value);
let y = builder.constant_biguint(&y_value);
let z = builder.mul_biguint(&x, &y);
let expected_z = builder.constant_biguint(&expected_z_value);
builder.connect_biguint(z, expected_z);
builder.connect_biguint(&z, &expected_z);
let data = builder.build();
let proof = data.prove(pw).unwrap();
@ -327,9 +327,9 @@ mod tests {
let pw = PartialWitness::new();
let mut builder = CircuitBuilder::<F, 4>::new(config);
let x = builder.constant_biguint(x_value);
let y = builder.constant_biguint(y_value);
let cmp = builder.cmp_biguint(x, y);
let x = builder.constant_biguint(&x_value);
let y = builder.constant_biguint(&y_value);
let cmp = builder.cmp_biguint(&x, &y);
let expected_cmp = builder.constant_bool(false);
builder.connect(cmp.target, expected_cmp.target);
@ -350,15 +350,15 @@ mod tests {
let pw = PartialWitness::new();
let mut builder = CircuitBuilder::<F, 4>::new(config);
let x = builder.constant_biguint(x_value);
let y = builder.constant_biguint(y_value);
let (div, rem) = builder.div_rem_biguint(x, y);
let x = builder.constant_biguint(&x_value);
let y = builder.constant_biguint(&y_value);
let (div, rem) = builder.div_rem_biguint(&x, &y);
let expected_div = builder.constant_biguint(expected_div_value);
let expected_rem = builder.constant_biguint(expected_rem_value);
let expected_div = builder.constant_biguint(&expected_div_value);
let expected_rem = builder.constant_biguint(&expected_rem_value);
builder.connect_biguint(div, expected_div);
builder.connect_biguint(rem, expected_rem);
builder.connect_biguint(&div, &expected_div);
builder.connect_biguint(&rem, &expected_rem);
let data = builder.build();
let proof = data.prove(pw).unwrap();

View File

@ -1,5 +1,6 @@
use std::marker::PhantomData;
use crate::gadgets::biguint::BigUintTarget;
use crate::field::field_types::RichField;
use crate::field::{extension_field::Extendable, field_types::Field};
use crate::gadgets::arithmetic_u32::U32Target;
@ -20,122 +21,67 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
.collect()
}
// Add two `ForeignFieldTarget`s.
pub fn add_nonnative<FF: Field>(
&mut self,
a: ForeignFieldTarget<FF>,
b: ForeignFieldTarget<FF>,
) -> ForeignFieldTarget<FF> {
let num_limbs = a.limbs.len();
debug_assert!(b.limbs.len() == num_limbs);
let mut combined_limbs = self.add_virtual_u32_targets(num_limbs + 1);
let mut carry = self.zero_u32();
for i in 0..num_limbs {
let (new_limb, new_carry) =
self.add_three_u32(carry.clone(), a.limbs[i].clone(), b.limbs[i].clone());
carry = new_carry;
combined_limbs[i] = new_limb;
}
combined_limbs[num_limbs] = carry;
let reduced_limbs = self.reduce_add_result::<FF>(combined_limbs);
ForeignFieldTarget {
limbs: reduced_limbs,
_phantom: PhantomData,
pub fn ff_to_biguint<FF: Field>(&mut self, x: &ForeignFieldTarget<FF>) -> BigUintTarget {
BigUintTarget {
limbs: x.limbs.clone(),
}
}
/// Reduces the result of a non-native addition.
pub fn reduce_add_result<FF: Field>(&mut self, limbs: Vec<U32Target>) -> Vec<U32Target> {
let num_limbs = limbs.len();
// Add two `ForeignFieldTarget`s.
pub fn add_nonnative<FF: Field>(
&mut self,
a: &ForeignFieldTarget<FF>,
b: &ForeignFieldTarget<FF>,
) -> ForeignFieldTarget<FF> {
let a_biguint = self.ff_to_biguint(a);
let b_biguint = self.ff_to_biguint(b);
let result = self.add_biguint(&a_biguint, &b_biguint);
let mut modulus_limbs = self.order_u32_limbs::<FF>();
modulus_limbs.push(self.zero_u32());
let needs_reduce = self.list_le_u32(modulus_limbs.clone(), limbs.clone());
let mut to_subtract = vec![];
for i in 0..num_limbs {
let (low, _high) = self.mul_u32(modulus_limbs[i], U32Target(needs_reduce.target));
to_subtract.push(low);
}
let mut reduced_limbs = vec![];
let mut borrow = self.zero_u32();
for i in 0..num_limbs {
let (result, new_borrow) = self.sub_u32(limbs[i], to_subtract[i], borrow);
reduced_limbs[i] = result;
borrow = new_borrow;
}
// Borrow should be zero here.
reduced_limbs
self.reduce(&result)
}
// Subtract two `ForeignFieldTarget`s. We assume that the first is larger than the second.
pub fn sub_nonnative<FF: Field>(
&mut self,
a: ForeignFieldTarget<FF>,
b: ForeignFieldTarget<FF>,
a: &ForeignFieldTarget<FF>,
b: &ForeignFieldTarget<FF>,
) -> ForeignFieldTarget<FF> {
let num_limbs = a.limbs.len();
debug_assert!(b.limbs.len() == num_limbs);
let a_biguint = self.ff_to_biguint(a);
let b_biguint = self.ff_to_biguint(b);
let result = self.sub_biguint(&a_biguint, &b_biguint);
let mut result_limbs = vec![];
let mut borrow = self.zero_u32();
for i in 0..num_limbs {
let (result, new_borrow) = self.sub_u32(a.limbs[i], b.limbs[i], borrow);
result_limbs[i] = result;
borrow = new_borrow;
}
// Borrow should be zero here.
ForeignFieldTarget {
limbs: result_limbs,
_phantom: PhantomData,
}
self.reduce(&result)
}
pub fn mul_nonnative<FF: Field>(
&mut self,
a: ForeignFieldTarget<FF>,
b: ForeignFieldTarget<FF>,
a: &ForeignFieldTarget<FF>,
b: &ForeignFieldTarget<FF>,
) -> ForeignFieldTarget<FF> {
let num_limbs = a.limbs.len();
debug_assert!(b.limbs.len() == num_limbs);
let a_biguint = self.ff_to_biguint(a);
let b_biguint = self.ff_to_biguint(b);
let result = self.mul_biguint(&a_biguint, &b_biguint);
let mut combined_limbs = self.add_virtual_u32_targets(2 * num_limbs - 1);
let mut to_add = vec![vec![]; 2 * num_limbs];
for i in 0..num_limbs {
for j in 0..num_limbs {
let (product, carry) = self.mul_u32(a.limbs[i], b.limbs[j]);
to_add[i + j].push(product);
to_add[i + j + 1].push(carry);
}
}
self.reduce(&result)
}
let mut combined_limbs = vec![];
let mut carry = self.zero_u32();
for i in 0..2 * num_limbs {
to_add[i].push(carry);
let (new_result, new_carry) = self.add_many_u32(to_add[i].clone());
combined_limbs.push(new_result);
carry = new_carry;
}
combined_limbs.push(carry);
let reduced_limbs = self.reduce_mul_result::<FF>(combined_limbs);
/// Returns `x % |FF|` as a `ForeignFieldTarget`.
fn reduce<FF: Field>(
&mut self,
x: &BigUintTarget,
) -> ForeignFieldTarget<FF> {
let modulus = FF::order();
let order_target = self.constant_biguint(&modulus);
let value = self.rem_biguint(x, &order_target);
ForeignFieldTarget {
limbs: reduced_limbs,
limbs: value.limbs,
_phantom: PhantomData,
}
}
pub fn reduce_mul_result<FF: Field>(&mut self, limbs: Vec<U32Target>) -> Vec<U32Target> {
todo!()
fn reduce_ff<FF: Field>(&mut self, x: &ForeignFieldTarget<FF>) -> ForeignFieldTarget<FF> {
let x_biguint = self.ff_to_biguint(x);
self.reduce(&x_biguint)
}
}