more efficient nonnative add and multi-add

This commit is contained in:
Nicholas Ward 2022-01-18 14:07:47 -08:00
parent facb5661f3
commit 50c24dfe8a
5 changed files with 151 additions and 68 deletions

View File

@ -40,6 +40,10 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
BigUintTarget { limbs }
}
pub fn zero_biguint(&mut self) -> BigUintTarget {
self.constant_biguint(&BigUint::zero())
}
pub fn connect_biguint(&mut self, lhs: &BigUintTarget, rhs: &BigUintTarget) {
let min_limbs = lhs.num_limbs().min(rhs.num_limbs());
for i in 0..min_limbs {
@ -159,6 +163,18 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
}
}
pub fn mul_biguint_by_bool(
&mut self,
a: &BigUintTarget,
b: BoolTarget,
) -> BigUintTarget {
let t = b.target;
BigUintTarget {
limbs: a.limbs.iter().map(|l| U32Target(self.mul(l.0, t))).collect()
}
}
// Returns x * y + z. This is no more efficient than mul-then-add; it's purely for convenience (only need to call one CircuitBuilder function).
pub fn mul_add_biguint(
&mut self,
@ -396,11 +412,11 @@ mod tests {
let y = builder.constant_biguint(&y_value);
let (div, rem) = builder.div_rem_biguint(&x, &y);
// let expected_div = builder.constant_biguint(&expected_div_value);
// let expected_rem = builder.constant_biguint(&expected_rem_value);
let expected_div = builder.constant_biguint(&expected_div_value);
let expected_rem = builder.constant_biguint(&expected_rem_value);
// builder.connect_biguint(&div, &expected_div);
// builder.connect_biguint(&rem, &expected_rem);
builder.connect_biguint(&div, &expected_div);
builder.connect_biguint(&rem, &expected_rem);
let data = builder.build::<C>();
let proof = data.prove(pw).unwrap();

View File

@ -100,6 +100,7 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
p1: &AffinePointTarget<C>,
p2: &AffinePointTarget<C>,
) -> AffinePointTarget<C> {
let before = self.num_gates();
let AffinePointTarget { x: x1, y: y1 } = p1;
let AffinePointTarget { x: x2, y: y2 } = p2;
@ -123,6 +124,7 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
let x3_norm = self.mul_nonnative(&x3, &z3_inv);
let y3_norm = self.mul_nonnative(&y3, &z3_inv);
println!("NUM GATES: {}", self.num_gates() - before);
AffinePointTarget {
x: x3_norm,
y: y3_norm,
@ -310,7 +312,6 @@ mod tests {
}
#[test]
#[ignore]
fn test_curve_mul() -> Result<()> {
const D: usize = 2;
type C = PoseidonGoldilocksConfig;
@ -345,7 +346,6 @@ mod tests {
}
#[test]
#[ignore]
fn test_curve_random() -> Result<()> {
const D: usize = 2;
type C = PoseidonGoldilocksConfig;

View File

@ -60,16 +60,45 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
}
}
// Add two `NonNativeTarget`s.
pub fn add_nonnative<FF: Field>(
&mut self,
a: &NonNativeTarget<FF>,
b: &NonNativeTarget<FF>,
) -> NonNativeTarget<FF> {
let result = self.add_biguint(&a.value, &b.value);
let sum = self.add_virtual_nonnative_target::<FF>();
let overflow = self.add_virtual_bool_target();
// TODO: reduce add result with only one conditional subtraction
self.reduce(&result)
self.add_simple_generator(NonNativeAdditionGenerator::<F, D, FF> {
a: a.clone(),
b: b.clone(),
sum: sum.clone(),
overflow: overflow.clone(),
_phantom: PhantomData,
});
let sum_expected = self.add_biguint(&a.value, &b.value);
let modulus = self.constant_biguint(&FF::order());
let mod_times_overflow = self.mul_biguint_by_bool(&modulus, overflow);
let sum_actual = self.add_biguint(&sum.value, &mod_times_overflow);
self.connect_biguint(&sum_expected, &sum_actual);
sum
}
pub fn mul_nonnative_by_bool<FF: Field>(
&mut self,
a: &NonNativeTarget<FF>,
b: BoolTarget,
) -> NonNativeTarget<FF> {
let t = b.target;
NonNativeTarget {
value: BigUintTarget {
limbs: a.value.limbs.iter().map(|l| U32Target(self.mul(l.0, t))).collect()
},
_phantom: PhantomData,
}
}
pub fn add_many_nonnative<FF: Field>(
@ -80,12 +109,28 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
return to_add[0].clone();
}
let mut result = self.add_biguint(&to_add[0].value, &to_add[1].value);
for i in 2..to_add.len() {
result = self.add_biguint(&result, &to_add[i].value);
}
let sum = self.add_virtual_nonnative_target::<FF>();
let overflow = self.add_virtual_u32_target();
let summands = to_add.to_vec();
self.reduce(&result)
self.add_simple_generator(NonNativeMultipleAddsGenerator::<F, D, FF> {
summands: summands.clone(),
sum: sum.clone(),
overflow: overflow.clone(),
_phantom: PhantomData,
});
let sum_expected = summands.iter().fold(self.zero_biguint(), |a, b| self.add_biguint(&a, &b.value));
let modulus = self.constant_biguint(&FF::order());
let overflow_biguint = BigUintTarget {
limbs: vec![overflow],
};
let mod_times_overflow = self.mul_biguint(&modulus, &overflow_biguint);
let sum_actual = self.add_biguint(&sum.value, &mod_times_overflow);
self.connect_biguint(&sum_expected, &sum_actual);
sum
}
// Subtract two `NonNativeTarget`s.
@ -188,59 +233,6 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
}
}
/// Returns `x % |FF|` as a `NonNativeTarget`.
/*fn reduce_by_bits<FF: Field>(&mut self, x: &BigUintTarget) -> NonNativeTarget<FF> {
let before = self.num_gates();
let mut powers_of_two = Vec::new();
let mut cur_power_of_two = FF::ONE;
let two = FF::TWO;
let mut max_num_limbs = 0;
for _ in 0..(x.limbs.len() * 32) {
let cur_power = self.constant_biguint(&cur_power_of_two.to_biguint());
max_num_limbs = max_num_limbs.max(cur_power.limbs.len());
powers_of_two.push(cur_power.limbs);
cur_power_of_two *= two;
}
let mut result_limbs_unreduced = vec![self.zero(); max_num_limbs];
for i in 0..x.limbs.len() {
let this_limb = x.limbs[i];
let bits = self.split_le(this_limb.0, 32);
for b in 0..bits.len() {
let this_power = powers_of_two[32 * i + b].clone();
for x in 0..this_power.len() {
result_limbs_unreduced[x] = self.mul_add(bits[b].target, this_power[x].0, result_limbs_unreduced[x]);
}
}
}
let mut result_limbs_reduced = Vec::new();
let mut carry = self.zero_u32();
for i in 0..result_limbs_unreduced.len() {
println!("{}", i);
let (low, high) = self.split_to_u32(result_limbs_unreduced[i]);
let (cur, overflow) = self.add_u32(carry, low);
let (new_carry, _) = self.add_many_u32(&[overflow, high, carry]);
result_limbs_reduced.push(cur);
carry = new_carry;
}
result_limbs_reduced.push(carry);
let value = BigUintTarget {
limbs: result_limbs_reduced,
};
println!("NUMBER OF GATES: {}", self.num_gates() - before);
println!("OUTPUT LIMBS: {}", value.limbs.len());
NonNativeTarget {
value,
_phantom: PhantomData,
}
}*/
#[allow(dead_code)]
fn reduce_nonnative<FF: Field>(&mut self, x: &NonNativeTarget<FF>) -> NonNativeTarget<FF> {
let x_biguint = self.nonnative_to_biguint(x);
@ -280,6 +272,74 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
}
}
#[derive(Debug)]
struct NonNativeAdditionGenerator<F: RichField + Extendable<D>, const D: usize, FF: Field> {
a: NonNativeTarget<FF>,
b: NonNativeTarget<FF>,
sum: NonNativeTarget<FF>,
overflow: BoolTarget,
_phantom: PhantomData<F>,
}
impl<F: RichField + Extendable<D>, const D: usize, FF: Field> SimpleGenerator<F>
for NonNativeAdditionGenerator<F, D, FF>
{
fn dependencies(&self) -> Vec<Target> {
self.a.value.limbs.iter().cloned().chain(self.b.value.limbs.clone())
.map(|l| l.0)
.collect()
}
fn run_once(&self, witness: &PartitionWitness<F>, out_buffer: &mut GeneratedValues<F>) {
let a = witness.get_nonnative_target(self.a.clone());
let b = witness.get_nonnative_target(self.b.clone());
let a_biguint = a.to_biguint();
let b_biguint = b.to_biguint();
let sum_biguint = a_biguint + b_biguint;
let modulus = FF::order();
let (overflow, sum_reduced) = if sum_biguint > modulus {
(true, sum_biguint - modulus)
} else {
(false, sum_biguint)
};
out_buffer.set_biguint_target(self.sum.value.clone(), sum_reduced);
out_buffer.set_bool_target(self.overflow, overflow);
}
}
#[derive(Debug)]
struct NonNativeMultipleAddsGenerator<F: RichField + Extendable<D>, const D: usize, FF: Field> {
summands: Vec<NonNativeTarget<FF>>,
sum: NonNativeTarget<FF>,
overflow: U32Target,
_phantom: PhantomData<F>,
}
impl<F: RichField + Extendable<D>, const D: usize, FF: Field> SimpleGenerator<F>
for NonNativeMultipleAddsGenerator<F, D, FF>
{
fn dependencies(&self) -> Vec<Target> {
self.summands.iter().map(|summand| summand.value.limbs.iter().map(|limb| limb.0))
.flatten()
.collect()
}
fn run_once(&self, witness: &PartitionWitness<F>, out_buffer: &mut GeneratedValues<F>) {
let summands: Vec<_> = self.summands.iter().map(|summand| witness.get_nonnative_target(summand.clone())).collect();
let summand_biguints: Vec<_> = summands.iter().map(|summand| summand.to_biguint()).collect();
let sum_biguint = summand_biguints.iter().fold(BigUint::zero(), |a, b| a + b.clone());
let modulus = FF::order();
let (overflow_biguint, sum_reduced) = sum_biguint.div_rem(&modulus);
let overflow = overflow_biguint.to_u64_digits()[0] as u32;
out_buffer.set_biguint_target(self.sum.value.clone(), sum_reduced);
out_buffer.set_u32_target(self.overflow, overflow);
}
}
#[derive(Debug)]
struct NonNativeInverseGenerator<F: RichField + Extendable<D>, const D: usize, FF: Field> {
x: NonNativeTarget<FF>,
@ -310,6 +370,8 @@ impl<F: RichField + Extendable<D>, const D: usize, FF: Field> SimpleGenerator<F>
}
}
#[cfg(test)]
mod tests {
use anyhow::Result;

View File

@ -11,6 +11,7 @@ pub mod binary_arithmetic;
pub mod binary_subtraction;
pub mod comparison;
pub mod constant;
// pub mod curve_double;
pub mod exponentiation;
pub mod gate;
pub mod gate_tree;

View File

@ -162,6 +162,10 @@ impl<F: Field> GeneratedValues<F> {
self.target_values.push((target, value))
}
pub fn set_bool_target(&mut self, target: BoolTarget, value: bool) {
self.set_target(target.target, F::from_bool(value))
}
pub fn set_u32_target(&mut self, target: U32Target, value: u32) {
self.set_target(target.0, F::from_canonical_u32(value))
}