more efficient nonnative subtraction

This commit is contained in:
Nicholas Ward 2022-01-18 14:59:39 -08:00
parent 8d3662692e
commit ddf5ee5d1f

View File

@ -146,12 +146,24 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
a: &NonNativeTarget<FF>,
b: &NonNativeTarget<FF>,
) -> NonNativeTarget<FF> {
let order = self.constant_biguint(&FF::order());
let a_plus_order = self.add_biguint(&order, &a.value);
let result = self.sub_biguint(&a_plus_order, &b.value);
let diff = self.add_virtual_nonnative_target::<FF>();
let overflow = self.add_virtual_bool_target();
// TODO: reduce sub result with only one conditional addition?
self.reduce(&result)
self.add_simple_generator(NonNativeSubtractionGenerator::<F, D, FF> {
a: a.clone(),
b: b.clone(),
diff: diff.clone(),
overflow: overflow.clone(),
_phantom: PhantomData,
});
let diff_plus_b = self.add_biguint(&diff.value, &b.value);
let modulus = self.constant_biguint(&FF::order());
let mod_times_overflow = self.mul_biguint_by_bool(&modulus, overflow);
let diff_plus_b_reduced = self.sub_biguint(&diff_plus_b, &mod_times_overflow);
self.connect_biguint(&a.value, &diff_plus_b_reduced);
diff
}
pub fn mul_nonnative<FF: Field>(
@ -363,6 +375,47 @@ impl<F: RichField + Extendable<D>, const D: usize, FF: Field> SimpleGenerator<F>
}
}
#[derive(Debug)]
struct NonNativeSubtractionGenerator<F: RichField + Extendable<D>, const D: usize, FF: Field> {
a: NonNativeTarget<FF>,
b: NonNativeTarget<FF>,
diff: NonNativeTarget<FF>,
overflow: BoolTarget,
_phantom: PhantomData<F>,
}
impl<F: RichField + Extendable<D>, const D: usize, FF: Field> SimpleGenerator<F>
for NonNativeSubtractionGenerator<F, D, FF>
{
fn dependencies(&self) -> Vec<Target> {
self.a
.value
.limbs
.iter()
.cloned()
.chain(self.b.value.limbs.clone())
.map(|l| l.0)
.collect()
}
fn run_once(&self, witness: &PartitionWitness<F>, out_buffer: &mut GeneratedValues<F>) {
let a = witness.get_nonnative_target(self.a.clone());
let b = witness.get_nonnative_target(self.b.clone());
let a_biguint = a.to_biguint();
let b_biguint = b.to_biguint();
let modulus = FF::order();
let (diff_biguint, overflow) = if a_biguint > b_biguint {
(a_biguint - b_biguint, false)
} else {
(modulus + a_biguint - b_biguint, true)
};
out_buffer.set_biguint_target(self.diff.value.clone(), diff_biguint);
out_buffer.set_bool_target(self.overflow, overflow);
}
}
#[derive(Debug)]
struct NonNativeInverseGenerator<F: RichField + Extendable<D>, const D: usize, FF: Field> {
x: NonNativeTarget<FF>,