diff --git a/plonky2/src/gadgets/nonnative.rs b/plonky2/src/gadgets/nonnative.rs index 83333491..8726fde7 100644 --- a/plonky2/src/gadgets/nonnative.rs +++ b/plonky2/src/gadgets/nonnative.rs @@ -146,12 +146,24 @@ impl, const D: usize> CircuitBuilder { a: &NonNativeTarget, b: &NonNativeTarget, ) -> NonNativeTarget { - let order = self.constant_biguint(&FF::order()); - let a_plus_order = self.add_biguint(&order, &a.value); - let result = self.sub_biguint(&a_plus_order, &b.value); + let diff = self.add_virtual_nonnative_target::(); + let overflow = self.add_virtual_bool_target(); - // TODO: reduce sub result with only one conditional addition? - self.reduce(&result) + self.add_simple_generator(NonNativeSubtractionGenerator:: { + a: a.clone(), + b: b.clone(), + diff: diff.clone(), + overflow: overflow.clone(), + _phantom: PhantomData, + }); + + let diff_plus_b = self.add_biguint(&diff.value, &b.value); + let modulus = self.constant_biguint(&FF::order()); + let mod_times_overflow = self.mul_biguint_by_bool(&modulus, overflow); + let diff_plus_b_reduced = self.sub_biguint(&diff_plus_b, &mod_times_overflow); + self.connect_biguint(&a.value, &diff_plus_b_reduced); + + diff } pub fn mul_nonnative( @@ -363,6 +375,47 @@ impl, const D: usize, FF: Field> SimpleGenerator } } +#[derive(Debug)] +struct NonNativeSubtractionGenerator, const D: usize, FF: Field> { + a: NonNativeTarget, + b: NonNativeTarget, + diff: NonNativeTarget, + overflow: BoolTarget, + _phantom: PhantomData, +} + +impl, const D: usize, FF: Field> SimpleGenerator + for NonNativeSubtractionGenerator +{ + fn dependencies(&self) -> Vec { + self.a + .value + .limbs + .iter() + .cloned() + .chain(self.b.value.limbs.clone()) + .map(|l| l.0) + .collect() + } + + fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { + let a = witness.get_nonnative_target(self.a.clone()); + let b = witness.get_nonnative_target(self.b.clone()); + let a_biguint = a.to_biguint(); + let b_biguint = b.to_biguint(); + + let modulus = FF::order(); + let (diff_biguint, overflow) = if a_biguint > b_biguint { + (a_biguint - b_biguint, false) + } else { + (modulus + a_biguint - b_biguint, true) + }; + + out_buffer.set_biguint_target(self.diff.value.clone(), diff_biguint); + out_buffer.set_bool_target(self.overflow, overflow); + } +} + #[derive(Debug)] struct NonNativeInverseGenerator, const D: usize, FF: Field> { x: NonNativeTarget,