diff --git a/src/gadgets/arithmetic_extension.rs b/src/gadgets/arithmetic_extension.rs index 33165ce2..fa6df6a9 100644 --- a/src/gadgets/arithmetic_extension.rs +++ b/src/gadgets/arithmetic_extension.rs @@ -125,6 +125,22 @@ impl, const D: usize> CircuitBuilder { ExtensionAlgebraTarget(res.try_into().unwrap()) } + /// Add 3 `ExtensionTarget`s with 1 `ArithmeticExtensionGate`s. + pub fn add_three_extension( + &mut self, + a: ExtensionTarget, + b: ExtensionTarget, + c: ExtensionTarget, + ) -> ExtensionTarget { + let one = self.one_extension(); + let gate = self.num_gates(); + let first_out = + ExtensionTarget::from_range(gate, ArithmeticExtensionGate::::wires_first_output()); + self.double_arithmetic_extension(F::ONE, F::ONE, one, a, b, one, c, first_out) + .1 + } + + /// Add `n` `ExtensionTarget`s with `n/2 + 1` `ArithmeticExtensionGate`s. pub fn add_many_extension(&mut self, terms: &[ExtensionTarget]) -> ExtensionTarget { let zero = self.zero_extension(); let mut terms = terms.to_vec(); @@ -246,6 +262,22 @@ impl, const D: usize> CircuitBuilder { ExtensionAlgebraTarget(res) } + /// Multiply 3 `ExtensionTarget`s with 1 `ArithmeticExtensionGate`s. + pub fn mul_three_extension( + &mut self, + a: ExtensionTarget, + b: ExtensionTarget, + c: ExtensionTarget, + ) -> ExtensionTarget { + let zero = self.zero_extension(); + let gate = self.num_gates(); + let first_out = + ExtensionTarget::from_range(gate, ArithmeticExtensionGate::::wires_first_output()); + self.double_arithmetic_extension(F::ONE, F::ZERO, a, b, zero, c, first_out, zero) + .1 + } + + /// Multiply `n` `ExtensionTarget`s with `n/2 + 1` `ArithmeticExtensionGate`s. pub fn mul_many_extension(&mut self, terms: &[ExtensionTarget]) -> ExtensionTarget { let one = self.one_extension(); let mut terms = terms.to_vec(); @@ -487,8 +519,8 @@ mod tests { let mut builder = CircuitBuilder::::new(config); let mut pw = PartialWitness::new(); - let vs = FF::rand_vec(20); - let ts = builder.add_virtual_extension_targets(20); + let vs = FF::rand_vec(3); + let ts = builder.add_virtual_extension_targets(3); for (&v, &t) in vs.iter().zip(&ts) { pw.set_extension_target(t, v); } @@ -500,10 +532,12 @@ mod tests { } acc }; - let mul2 = builder.constant_extension(vs.into_iter().product()); + let mul2 = builder.mul_three_extension(ts[0], ts[1], ts[2]); + let mul3 = builder.constant_extension(vs.into_iter().product()); builder.assert_equal_extension(mul0, mul1); builder.assert_equal_extension(mul1, mul2); + builder.assert_equal_extension(mul2, mul3); let data = builder.build(); let proof = data.prove(pw)?;