{add|mul}_three_extension

This commit is contained in:
wborgeaud 2021-07-21 17:29:05 +02:00
parent b59d497964
commit d870a36dee

View File

@ -125,6 +125,22 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
ExtensionAlgebraTarget(res.try_into().unwrap())
}
/// Add 3 `ExtensionTarget`s with 1 `ArithmeticExtensionGate`s.
pub fn add_three_extension(
&mut self,
a: ExtensionTarget<D>,
b: ExtensionTarget<D>,
c: ExtensionTarget<D>,
) -> ExtensionTarget<D> {
let one = self.one_extension();
let gate = self.num_gates();
let first_out =
ExtensionTarget::from_range(gate, ArithmeticExtensionGate::<D>::wires_first_output());
self.double_arithmetic_extension(F::ONE, F::ONE, one, a, b, one, c, first_out)
.1
}
/// Add `n` `ExtensionTarget`s with `n/2 + 1` `ArithmeticExtensionGate`s.
pub fn add_many_extension(&mut self, terms: &[ExtensionTarget<D>]) -> ExtensionTarget<D> {
let zero = self.zero_extension();
let mut terms = terms.to_vec();
@ -246,6 +262,22 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
ExtensionAlgebraTarget(res)
}
/// Multiply 3 `ExtensionTarget`s with 1 `ArithmeticExtensionGate`s.
pub fn mul_three_extension(
&mut self,
a: ExtensionTarget<D>,
b: ExtensionTarget<D>,
c: ExtensionTarget<D>,
) -> ExtensionTarget<D> {
let zero = self.zero_extension();
let gate = self.num_gates();
let first_out =
ExtensionTarget::from_range(gate, ArithmeticExtensionGate::<D>::wires_first_output());
self.double_arithmetic_extension(F::ONE, F::ZERO, a, b, zero, c, first_out, zero)
.1
}
/// Multiply `n` `ExtensionTarget`s with `n/2 + 1` `ArithmeticExtensionGate`s.
pub fn mul_many_extension(&mut self, terms: &[ExtensionTarget<D>]) -> ExtensionTarget<D> {
let one = self.one_extension();
let mut terms = terms.to_vec();
@ -487,8 +519,8 @@ mod tests {
let mut builder = CircuitBuilder::<F, D>::new(config);
let mut pw = PartialWitness::new();
let vs = FF::rand_vec(20);
let ts = builder.add_virtual_extension_targets(20);
let vs = FF::rand_vec(3);
let ts = builder.add_virtual_extension_targets(3);
for (&v, &t) in vs.iter().zip(&ts) {
pw.set_extension_target(t, v);
}
@ -500,10 +532,12 @@ mod tests {
}
acc
};
let mul2 = builder.constant_extension(vs.into_iter().product());
let mul2 = builder.mul_three_extension(ts[0], ts[1], ts[2]);
let mul3 = builder.constant_extension(vs.into_iter().product());
builder.assert_equal_extension(mul0, mul1);
builder.assert_equal_extension(mul1, mul2);
builder.assert_equal_extension(mul2, mul3);
let data = builder.build();
let proof = data.prove(pw)?;