mirror of
https://github.com/logos-storage/plonky2.git
synced 2026-01-07 08:13:11 +00:00
Merge pull request #117 from mir-protocol/optimize_mul_many
Remove fixed multiplicand in `ArithmeticExtensionGate`
This commit is contained in:
commit
2f46ddc4e5
@ -16,7 +16,8 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
|
|||||||
|
|
||||||
/// Computes `x^3`.
|
/// Computes `x^3`.
|
||||||
pub fn cube(&mut self, x: Target) -> Target {
|
pub fn cube(&mut self, x: Target) -> Target {
|
||||||
self.mul_many(&[x, x, x])
|
let xe = self.convert_to_ext(x);
|
||||||
|
self.mul_three_extension(xe, xe, xe).to_target_array()[0]
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Computes `const_0 * multiplicand_0 * multiplicand_1 + const_1 * addend`.
|
/// Computes `const_0 * multiplicand_0 * multiplicand_1 + const_1 * addend`.
|
||||||
@ -123,13 +124,14 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
|
|||||||
self.arithmetic(F::ONE, x, one, F::ONE, y)
|
self.arithmetic(F::ONE, x, one, F::ONE, y)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Add `n` `Target`s with `ceil(n/2) + 1` `ArithmeticExtensionGate`s.
|
||||||
// TODO: Can be made `2*D` times more efficient by using all wires of an `ArithmeticExtensionGate`.
|
// TODO: Can be made `2*D` times more efficient by using all wires of an `ArithmeticExtensionGate`.
|
||||||
pub fn add_many(&mut self, terms: &[Target]) -> Target {
|
pub fn add_many(&mut self, terms: &[Target]) -> Target {
|
||||||
let mut sum = self.zero();
|
let terms_ext = terms
|
||||||
for term in terms {
|
.iter()
|
||||||
sum = self.add(sum, *term);
|
.map(|&t| self.convert_to_ext(t))
|
||||||
}
|
.collect::<Vec<_>>();
|
||||||
sum
|
self.add_many_extension(&terms_ext).to_target_array()[0]
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Computes `x - y`.
|
/// Computes `x - y`.
|
||||||
@ -145,12 +147,13 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
|
|||||||
self.arithmetic(F::ONE, x, y, F::ZERO, x)
|
self.arithmetic(F::ONE, x, y, F::ZERO, x)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Multiply `n` `Target`s with `ceil(n/2) + 1` `ArithmeticExtensionGate`s.
|
||||||
pub fn mul_many(&mut self, terms: &[Target]) -> Target {
|
pub fn mul_many(&mut self, terms: &[Target]) -> Target {
|
||||||
let mut product = self.one();
|
let terms_ext = terms
|
||||||
for term in terms {
|
.iter()
|
||||||
product = self.mul(product, *term);
|
.map(|&t| self.convert_to_ext(t))
|
||||||
}
|
.collect::<Vec<_>>();
|
||||||
product
|
self.mul_many_extension(&terms_ext).to_target_array()[0]
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Exponentiate `base` to the power of `2^power_log`.
|
/// Exponentiate `base` to the power of `2^power_log`.
|
||||||
|
|||||||
@ -17,37 +17,47 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
|
|||||||
&mut self,
|
&mut self,
|
||||||
const_0: F,
|
const_0: F,
|
||||||
const_1: F,
|
const_1: F,
|
||||||
fixed_multiplicand: ExtensionTarget<D>,
|
first_multiplicand_0: ExtensionTarget<D>,
|
||||||
multiplicand_0: ExtensionTarget<D>,
|
first_multiplicand_1: ExtensionTarget<D>,
|
||||||
addend_0: ExtensionTarget<D>,
|
first_addend: ExtensionTarget<D>,
|
||||||
multiplicand_1: ExtensionTarget<D>,
|
second_multiplicand_0: ExtensionTarget<D>,
|
||||||
addend_1: ExtensionTarget<D>,
|
second_multiplicand_1: ExtensionTarget<D>,
|
||||||
|
second_addend: ExtensionTarget<D>,
|
||||||
) -> (ExtensionTarget<D>, ExtensionTarget<D>) {
|
) -> (ExtensionTarget<D>, ExtensionTarget<D>) {
|
||||||
let gate = self.add_gate(ArithmeticExtensionGate::new(), vec![const_0, const_1]);
|
let gate = self.add_gate(ArithmeticExtensionGate::new(), vec![const_0, const_1]);
|
||||||
|
|
||||||
let wire_fixed_multiplicand = ExtensionTarget::from_range(
|
let wire_first_multiplicand_0 = ExtensionTarget::from_range(
|
||||||
gate,
|
gate,
|
||||||
ArithmeticExtensionGate::<D>::wires_fixed_multiplicand(),
|
ArithmeticExtensionGate::<D>::wires_first_multiplicand_0(),
|
||||||
);
|
);
|
||||||
let wire_multiplicand_0 =
|
let wire_first_multiplicand_1 = ExtensionTarget::from_range(
|
||||||
ExtensionTarget::from_range(gate, ArithmeticExtensionGate::<D>::wires_multiplicand_0());
|
gate,
|
||||||
let wire_addend_0 =
|
ArithmeticExtensionGate::<D>::wires_first_multiplicand_1(),
|
||||||
ExtensionTarget::from_range(gate, ArithmeticExtensionGate::<D>::wires_addend_0());
|
);
|
||||||
let wire_multiplicand_1 =
|
let wire_first_addend =
|
||||||
ExtensionTarget::from_range(gate, ArithmeticExtensionGate::<D>::wires_multiplicand_1());
|
ExtensionTarget::from_range(gate, ArithmeticExtensionGate::<D>::wires_first_addend());
|
||||||
let wire_addend_1 =
|
let wire_second_multiplicand_0 = ExtensionTarget::from_range(
|
||||||
ExtensionTarget::from_range(gate, ArithmeticExtensionGate::<D>::wires_addend_1());
|
gate,
|
||||||
let wire_output_0 =
|
ArithmeticExtensionGate::<D>::wires_second_multiplicand_0(),
|
||||||
ExtensionTarget::from_range(gate, ArithmeticExtensionGate::<D>::wires_output_0());
|
);
|
||||||
let wire_output_1 =
|
let wire_second_multiplicand_1 = ExtensionTarget::from_range(
|
||||||
ExtensionTarget::from_range(gate, ArithmeticExtensionGate::<D>::wires_output_1());
|
gate,
|
||||||
|
ArithmeticExtensionGate::<D>::wires_second_multiplicand_1(),
|
||||||
|
);
|
||||||
|
let wire_second_addend =
|
||||||
|
ExtensionTarget::from_range(gate, ArithmeticExtensionGate::<D>::wires_second_addend());
|
||||||
|
let wire_first_output =
|
||||||
|
ExtensionTarget::from_range(gate, ArithmeticExtensionGate::<D>::wires_first_output());
|
||||||
|
let wire_second_output =
|
||||||
|
ExtensionTarget::from_range(gate, ArithmeticExtensionGate::<D>::wires_second_output());
|
||||||
|
|
||||||
self.route_extension(fixed_multiplicand, wire_fixed_multiplicand);
|
self.route_extension(first_multiplicand_0, wire_first_multiplicand_0);
|
||||||
self.route_extension(multiplicand_0, wire_multiplicand_0);
|
self.route_extension(first_multiplicand_1, wire_first_multiplicand_1);
|
||||||
self.route_extension(addend_0, wire_addend_0);
|
self.route_extension(first_addend, wire_first_addend);
|
||||||
self.route_extension(multiplicand_1, wire_multiplicand_1);
|
self.route_extension(second_multiplicand_0, wire_second_multiplicand_0);
|
||||||
self.route_extension(addend_1, wire_addend_1);
|
self.route_extension(second_multiplicand_1, wire_second_multiplicand_1);
|
||||||
(wire_output_0, wire_output_1)
|
self.route_extension(second_addend, wire_second_addend);
|
||||||
|
(wire_first_output, wire_second_output)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn arithmetic_extension(
|
pub fn arithmetic_extension(
|
||||||
@ -67,6 +77,7 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
|
|||||||
addend,
|
addend,
|
||||||
zero,
|
zero,
|
||||||
zero,
|
zero,
|
||||||
|
zero,
|
||||||
)
|
)
|
||||||
.0
|
.0
|
||||||
}
|
}
|
||||||
@ -80,6 +91,7 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
|
|||||||
self.arithmetic_extension(F::ONE, F::ONE, one, a, b)
|
self.arithmetic_extension(F::ONE, F::ONE, one, a, b)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Returns `(a0+b0, a1+b1)`.
|
||||||
pub fn add_two_extension(
|
pub fn add_two_extension(
|
||||||
&mut self,
|
&mut self,
|
||||||
a0: ExtensionTarget<D>,
|
a0: ExtensionTarget<D>,
|
||||||
@ -88,7 +100,7 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
|
|||||||
b1: ExtensionTarget<D>,
|
b1: ExtensionTarget<D>,
|
||||||
) -> (ExtensionTarget<D>, ExtensionTarget<D>) {
|
) -> (ExtensionTarget<D>, ExtensionTarget<D>) {
|
||||||
let one = self.one_extension();
|
let one = self.one_extension();
|
||||||
self.double_arithmetic_extension(F::ONE, F::ONE, one, a0, b0, a1, b1)
|
self.double_arithmetic_extension(F::ONE, F::ONE, one, a0, b0, one, a1, b1)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn add_ext_algebra(
|
pub fn add_ext_algebra(
|
||||||
@ -113,20 +125,39 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
|
|||||||
ExtensionAlgebraTarget(res.try_into().unwrap())
|
ExtensionAlgebraTarget(res.try_into().unwrap())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Add 3 `ExtensionTarget`s with 1 `ArithmeticExtensionGate`s.
|
||||||
|
pub fn add_three_extension(
|
||||||
|
&mut self,
|
||||||
|
a: ExtensionTarget<D>,
|
||||||
|
b: ExtensionTarget<D>,
|
||||||
|
c: ExtensionTarget<D>,
|
||||||
|
) -> ExtensionTarget<D> {
|
||||||
|
let one = self.one_extension();
|
||||||
|
let gate = self.num_gates();
|
||||||
|
let first_out =
|
||||||
|
ExtensionTarget::from_range(gate, ArithmeticExtensionGate::<D>::wires_first_output());
|
||||||
|
self.double_arithmetic_extension(F::ONE, F::ONE, one, a, b, one, c, first_out)
|
||||||
|
.1
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Add `n` `ExtensionTarget`s with `n/2` `ArithmeticExtensionGate`s.
|
||||||
pub fn add_many_extension(&mut self, terms: &[ExtensionTarget<D>]) -> ExtensionTarget<D> {
|
pub fn add_many_extension(&mut self, terms: &[ExtensionTarget<D>]) -> ExtensionTarget<D> {
|
||||||
let zero = self.zero_extension();
|
let zero = self.zero_extension();
|
||||||
let mut terms = terms.to_vec();
|
let mut terms = terms.to_vec();
|
||||||
if terms.len().is_odd() {
|
if terms.is_empty() {
|
||||||
|
return zero;
|
||||||
|
} else if terms.len() < 3 {
|
||||||
|
terms.resize(3, zero);
|
||||||
|
} else if terms.len().is_even() {
|
||||||
terms.push(zero);
|
terms.push(zero);
|
||||||
}
|
}
|
||||||
// We maintain two accumulators, one for the sum of even elements, and one for odd elements.
|
|
||||||
let mut acc0 = zero;
|
let mut acc = self.add_three_extension(terms[0], terms[1], terms[2]);
|
||||||
let mut acc1 = zero;
|
terms.drain(0..3);
|
||||||
for chunk in terms.chunks_exact(2) {
|
for chunk in terms.chunks_exact(2) {
|
||||||
(acc0, acc1) = self.add_two_extension(acc0, chunk[0], acc1, chunk[1]);
|
acc = self.add_three_extension(acc, chunk[0], chunk[1]);
|
||||||
}
|
}
|
||||||
// We sum both accumulators to get the final result.
|
acc
|
||||||
self.add_extension(acc0, acc1)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn sub_extension(
|
pub fn sub_extension(
|
||||||
@ -146,7 +177,7 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
|
|||||||
b1: ExtensionTarget<D>,
|
b1: ExtensionTarget<D>,
|
||||||
) -> (ExtensionTarget<D>, ExtensionTarget<D>) {
|
) -> (ExtensionTarget<D>, ExtensionTarget<D>) {
|
||||||
let one = self.one_extension();
|
let one = self.one_extension();
|
||||||
self.double_arithmetic_extension(F::ONE, F::NEG_ONE, one, a0, b0, a1, b1)
|
self.double_arithmetic_extension(F::ONE, F::NEG_ONE, one, a0, b0, one, a1, b1)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn sub_ext_algebra(
|
pub fn sub_ext_algebra(
|
||||||
@ -184,6 +215,7 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
|
|||||||
zero,
|
zero,
|
||||||
zero,
|
zero,
|
||||||
zero,
|
zero,
|
||||||
|
zero,
|
||||||
)
|
)
|
||||||
.0
|
.0
|
||||||
}
|
}
|
||||||
@ -196,6 +228,18 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
|
|||||||
self.mul_extension_with_const(F::ONE, multiplicand_0, multiplicand_1)
|
self.mul_extension_with_const(F::ONE, multiplicand_0, multiplicand_1)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Returns `(a0*b0, a1*b1)`.
|
||||||
|
pub fn mul_two_extension(
|
||||||
|
&mut self,
|
||||||
|
a0: ExtensionTarget<D>,
|
||||||
|
b0: ExtensionTarget<D>,
|
||||||
|
a1: ExtensionTarget<D>,
|
||||||
|
b1: ExtensionTarget<D>,
|
||||||
|
) -> (ExtensionTarget<D>, ExtensionTarget<D>) {
|
||||||
|
let zero = self.zero_extension();
|
||||||
|
self.double_arithmetic_extension(F::ONE, F::ZERO, a0, b0, zero, a1, b1, zero)
|
||||||
|
}
|
||||||
|
|
||||||
/// Computes `x^2`.
|
/// Computes `x^2`.
|
||||||
pub fn square_extension(&mut self, x: ExtensionTarget<D>) -> ExtensionTarget<D> {
|
pub fn square_extension(&mut self, x: ExtensionTarget<D>) -> ExtensionTarget<D> {
|
||||||
self.mul_extension(x, x)
|
self.mul_extension(x, x)
|
||||||
@ -221,12 +265,38 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
|
|||||||
ExtensionAlgebraTarget(res)
|
ExtensionAlgebraTarget(res)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Multiply 3 `ExtensionTarget`s with 1 `ArithmeticExtensionGate`s.
|
||||||
|
pub fn mul_three_extension(
|
||||||
|
&mut self,
|
||||||
|
a: ExtensionTarget<D>,
|
||||||
|
b: ExtensionTarget<D>,
|
||||||
|
c: ExtensionTarget<D>,
|
||||||
|
) -> ExtensionTarget<D> {
|
||||||
|
let zero = self.zero_extension();
|
||||||
|
let gate = self.num_gates();
|
||||||
|
let first_out =
|
||||||
|
ExtensionTarget::from_range(gate, ArithmeticExtensionGate::<D>::wires_first_output());
|
||||||
|
self.double_arithmetic_extension(F::ONE, F::ZERO, a, b, zero, c, first_out, zero)
|
||||||
|
.1
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Multiply `n` `ExtensionTarget`s with `n/2` `ArithmeticExtensionGate`s.
|
||||||
pub fn mul_many_extension(&mut self, terms: &[ExtensionTarget<D>]) -> ExtensionTarget<D> {
|
pub fn mul_many_extension(&mut self, terms: &[ExtensionTarget<D>]) -> ExtensionTarget<D> {
|
||||||
let mut product = self.one_extension();
|
let one = self.one_extension();
|
||||||
for term in terms {
|
let mut terms = terms.to_vec();
|
||||||
product = self.mul_extension(product, *term);
|
if terms.is_empty() {
|
||||||
|
return one;
|
||||||
|
} else if terms.len() < 3 {
|
||||||
|
terms.resize(3, one);
|
||||||
|
} else if terms.len().is_even() {
|
||||||
|
terms.push(one);
|
||||||
}
|
}
|
||||||
product
|
let mut acc = self.mul_three_extension(terms[0], terms[1], terms[2]);
|
||||||
|
terms.drain(0..3);
|
||||||
|
for chunk in terms.chunks_exact(2) {
|
||||||
|
acc = self.mul_three_extension(acc, chunk[0], chunk[1]);
|
||||||
|
}
|
||||||
|
acc
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Like `mul_add`, but for `ExtensionTarget`s.
|
/// Like `mul_add`, but for `ExtensionTarget`s.
|
||||||
@ -443,6 +513,43 @@ mod tests {
|
|||||||
use crate::verifier::verify;
|
use crate::verifier::verify;
|
||||||
use crate::witness::PartialWitness;
|
use crate::witness::PartialWitness;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_mul_many() -> Result<()> {
|
||||||
|
type F = CrandallField;
|
||||||
|
type FF = QuarticCrandallField;
|
||||||
|
const D: usize = 4;
|
||||||
|
|
||||||
|
let config = CircuitConfig::large_config();
|
||||||
|
|
||||||
|
let mut builder = CircuitBuilder::<F, D>::new(config);
|
||||||
|
let mut pw = PartialWitness::new();
|
||||||
|
|
||||||
|
let vs = FF::rand_vec(3);
|
||||||
|
let ts = builder.add_virtual_extension_targets(3);
|
||||||
|
for (&v, &t) in vs.iter().zip(&ts) {
|
||||||
|
pw.set_extension_target(t, v);
|
||||||
|
}
|
||||||
|
let mul0 = builder.mul_many_extension(&ts);
|
||||||
|
let mul1 = {
|
||||||
|
let mut acc = builder.one_extension();
|
||||||
|
for &t in &ts {
|
||||||
|
acc = builder.mul_extension(acc, t);
|
||||||
|
}
|
||||||
|
acc
|
||||||
|
};
|
||||||
|
let mul2 = builder.mul_three_extension(ts[0], ts[1], ts[2]);
|
||||||
|
let mul3 = builder.constant_extension(vs.into_iter().product());
|
||||||
|
|
||||||
|
builder.assert_equal_extension(mul0, mul1);
|
||||||
|
builder.assert_equal_extension(mul1, mul2);
|
||||||
|
builder.assert_equal_extension(mul2, mul3);
|
||||||
|
|
||||||
|
let data = builder.build();
|
||||||
|
let proof = data.prove(pw)?;
|
||||||
|
|
||||||
|
verify(proof, &data.verifier_only, &data.common)
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_div_extension() -> Result<()> {
|
fn test_div_extension() -> Result<()> {
|
||||||
type F = CrandallField;
|
type F = CrandallField;
|
||||||
|
|||||||
@ -18,27 +18,30 @@ impl<const D: usize> ArithmeticExtensionGate<D> {
|
|||||||
GateRef::new(ArithmeticExtensionGate)
|
GateRef::new(ArithmeticExtensionGate)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn wires_fixed_multiplicand() -> Range<usize> {
|
pub fn wires_first_multiplicand_0() -> Range<usize> {
|
||||||
0..D
|
0..D
|
||||||
}
|
}
|
||||||
pub fn wires_multiplicand_0() -> Range<usize> {
|
pub fn wires_first_multiplicand_1() -> Range<usize> {
|
||||||
D..2 * D
|
D..2 * D
|
||||||
}
|
}
|
||||||
pub fn wires_addend_0() -> Range<usize> {
|
pub fn wires_first_addend() -> Range<usize> {
|
||||||
2 * D..3 * D
|
2 * D..3 * D
|
||||||
}
|
}
|
||||||
pub fn wires_multiplicand_1() -> Range<usize> {
|
pub fn wires_second_multiplicand_0() -> Range<usize> {
|
||||||
3 * D..4 * D
|
3 * D..4 * D
|
||||||
}
|
}
|
||||||
pub fn wires_addend_1() -> Range<usize> {
|
pub fn wires_second_multiplicand_1() -> Range<usize> {
|
||||||
4 * D..5 * D
|
4 * D..5 * D
|
||||||
}
|
}
|
||||||
pub fn wires_output_0() -> Range<usize> {
|
pub fn wires_second_addend() -> Range<usize> {
|
||||||
5 * D..6 * D
|
5 * D..6 * D
|
||||||
}
|
}
|
||||||
pub fn wires_output_1() -> Range<usize> {
|
pub fn wires_first_output() -> Range<usize> {
|
||||||
6 * D..7 * D
|
6 * D..7 * D
|
||||||
}
|
}
|
||||||
|
pub fn wires_second_output() -> Range<usize> {
|
||||||
|
7 * D..8 * D
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<F: Extendable<D>, const D: usize> Gate<F, D> for ArithmeticExtensionGate<D> {
|
impl<F: Extendable<D>, const D: usize> Gate<F, D> for ArithmeticExtensionGate<D> {
|
||||||
@ -50,21 +53,24 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for ArithmeticExtensionGate<D>
|
|||||||
let const_0 = vars.local_constants[0];
|
let const_0 = vars.local_constants[0];
|
||||||
let const_1 = vars.local_constants[1];
|
let const_1 = vars.local_constants[1];
|
||||||
|
|
||||||
let fixed_multiplicand = vars.get_local_ext_algebra(Self::wires_fixed_multiplicand());
|
let first_multiplicand_0 = vars.get_local_ext_algebra(Self::wires_first_multiplicand_0());
|
||||||
let multiplicand_0 = vars.get_local_ext_algebra(Self::wires_multiplicand_0());
|
let first_multiplicand_1 = vars.get_local_ext_algebra(Self::wires_first_multiplicand_1());
|
||||||
let addend_0 = vars.get_local_ext_algebra(Self::wires_addend_0());
|
let first_addend = vars.get_local_ext_algebra(Self::wires_first_addend());
|
||||||
let multiplicand_1 = vars.get_local_ext_algebra(Self::wires_multiplicand_1());
|
let second_multiplicand_0 = vars.get_local_ext_algebra(Self::wires_second_multiplicand_0());
|
||||||
let addend_1 = vars.get_local_ext_algebra(Self::wires_addend_1());
|
let second_multiplicand_1 = vars.get_local_ext_algebra(Self::wires_second_multiplicand_1());
|
||||||
let output_0 = vars.get_local_ext_algebra(Self::wires_output_0());
|
let second_addend = vars.get_local_ext_algebra(Self::wires_second_addend());
|
||||||
let output_1 = vars.get_local_ext_algebra(Self::wires_output_1());
|
let first_output = vars.get_local_ext_algebra(Self::wires_first_output());
|
||||||
|
let second_output = vars.get_local_ext_algebra(Self::wires_second_output());
|
||||||
|
|
||||||
let computed_output_0 =
|
let first_computed_output = first_multiplicand_0 * first_multiplicand_1 * const_0.into()
|
||||||
fixed_multiplicand * multiplicand_0 * const_0.into() + addend_0 * const_1.into();
|
+ first_addend * const_1.into();
|
||||||
let computed_output_1 =
|
let second_computed_output = second_multiplicand_0 * second_multiplicand_1 * const_0.into()
|
||||||
fixed_multiplicand * multiplicand_1 * const_0.into() + addend_1 * const_1.into();
|
+ second_addend * const_1.into();
|
||||||
|
|
||||||
let mut constraints = (output_0 - computed_output_0).to_basefield_array().to_vec();
|
let mut constraints = (first_output - first_computed_output)
|
||||||
constraints.extend((output_1 - computed_output_1).to_basefield_array());
|
.to_basefield_array()
|
||||||
|
.to_vec();
|
||||||
|
constraints.extend((second_output - second_computed_output).to_basefield_array());
|
||||||
constraints
|
constraints
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -76,26 +82,32 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for ArithmeticExtensionGate<D>
|
|||||||
let const_0 = vars.local_constants[0];
|
let const_0 = vars.local_constants[0];
|
||||||
let const_1 = vars.local_constants[1];
|
let const_1 = vars.local_constants[1];
|
||||||
|
|
||||||
let fixed_multiplicand = vars.get_local_ext_algebra(Self::wires_fixed_multiplicand());
|
let first_multiplicand_0 = vars.get_local_ext_algebra(Self::wires_first_multiplicand_0());
|
||||||
let multiplicand_0 = vars.get_local_ext_algebra(Self::wires_multiplicand_0());
|
let first_multiplicand_1 = vars.get_local_ext_algebra(Self::wires_first_multiplicand_1());
|
||||||
let addend_0 = vars.get_local_ext_algebra(Self::wires_addend_0());
|
let first_addend = vars.get_local_ext_algebra(Self::wires_first_addend());
|
||||||
let multiplicand_1 = vars.get_local_ext_algebra(Self::wires_multiplicand_1());
|
let second_multiplicand_0 = vars.get_local_ext_algebra(Self::wires_second_multiplicand_0());
|
||||||
let addend_1 = vars.get_local_ext_algebra(Self::wires_addend_1());
|
let second_multiplicand_1 = vars.get_local_ext_algebra(Self::wires_second_multiplicand_1());
|
||||||
let output_0 = vars.get_local_ext_algebra(Self::wires_output_0());
|
let second_addend = vars.get_local_ext_algebra(Self::wires_second_addend());
|
||||||
let output_1 = vars.get_local_ext_algebra(Self::wires_output_1());
|
let first_output = vars.get_local_ext_algebra(Self::wires_first_output());
|
||||||
|
let second_output = vars.get_local_ext_algebra(Self::wires_second_output());
|
||||||
|
|
||||||
let computed_output_0 = builder.mul_ext_algebra(fixed_multiplicand, multiplicand_0);
|
let first_computed_output =
|
||||||
let computed_output_0 = builder.scalar_mul_ext_algebra(const_0, computed_output_0);
|
builder.mul_ext_algebra(first_multiplicand_0, first_multiplicand_1);
|
||||||
let scaled_addend_0 = builder.scalar_mul_ext_algebra(const_1, addend_0);
|
let first_computed_output = builder.scalar_mul_ext_algebra(const_0, first_computed_output);
|
||||||
let computed_output_0 = builder.add_ext_algebra(computed_output_0, scaled_addend_0);
|
let first_scaled_addend = builder.scalar_mul_ext_algebra(const_1, first_addend);
|
||||||
|
let first_computed_output =
|
||||||
|
builder.add_ext_algebra(first_computed_output, first_scaled_addend);
|
||||||
|
|
||||||
let computed_output_1 = builder.mul_ext_algebra(fixed_multiplicand, multiplicand_1);
|
let second_computed_output =
|
||||||
let computed_output_1 = builder.scalar_mul_ext_algebra(const_0, computed_output_1);
|
builder.mul_ext_algebra(second_multiplicand_0, second_multiplicand_1);
|
||||||
let scaled_addend_1 = builder.scalar_mul_ext_algebra(const_1, addend_1);
|
let second_computed_output =
|
||||||
let computed_output_1 = builder.add_ext_algebra(computed_output_1, scaled_addend_1);
|
builder.scalar_mul_ext_algebra(const_0, second_computed_output);
|
||||||
|
let second_scaled_addend = builder.scalar_mul_ext_algebra(const_1, second_addend);
|
||||||
|
let second_computed_output =
|
||||||
|
builder.add_ext_algebra(second_computed_output, second_scaled_addend);
|
||||||
|
|
||||||
let diff_0 = builder.sub_ext_algebra(output_0, computed_output_0);
|
let diff_0 = builder.sub_ext_algebra(first_output, first_computed_output);
|
||||||
let diff_1 = builder.sub_ext_algebra(output_1, computed_output_1);
|
let diff_1 = builder.sub_ext_algebra(second_output, second_computed_output);
|
||||||
let mut constraints = diff_0.to_ext_target_array().to_vec();
|
let mut constraints = diff_0.to_ext_target_array().to_vec();
|
||||||
constraints.extend(diff_1.to_ext_target_array());
|
constraints.extend(diff_1.to_ext_target_array());
|
||||||
constraints
|
constraints
|
||||||
@ -120,7 +132,7 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for ArithmeticExtensionGate<D>
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn num_wires(&self) -> usize {
|
fn num_wires(&self) -> usize {
|
||||||
7 * D
|
8 * D
|
||||||
}
|
}
|
||||||
|
|
||||||
fn num_constants(&self) -> usize {
|
fn num_constants(&self) -> usize {
|
||||||
@ -150,9 +162,9 @@ struct ArithmeticExtensionGenerator1<F: Extendable<D>, const D: usize> {
|
|||||||
|
|
||||||
impl<F: Extendable<D>, const D: usize> SimpleGenerator<F> for ArithmeticExtensionGenerator0<F, D> {
|
impl<F: Extendable<D>, const D: usize> SimpleGenerator<F> for ArithmeticExtensionGenerator0<F, D> {
|
||||||
fn dependencies(&self) -> Vec<Target> {
|
fn dependencies(&self) -> Vec<Target> {
|
||||||
ArithmeticExtensionGate::<D>::wires_fixed_multiplicand()
|
ArithmeticExtensionGate::<D>::wires_first_multiplicand_0()
|
||||||
.chain(ArithmeticExtensionGate::<D>::wires_multiplicand_0())
|
.chain(ArithmeticExtensionGate::<D>::wires_first_multiplicand_1())
|
||||||
.chain(ArithmeticExtensionGate::<D>::wires_addend_0())
|
.chain(ArithmeticExtensionGate::<D>::wires_first_addend())
|
||||||
.map(|i| Target::wire(self.gate_index, i))
|
.map(|i| Target::wire(self.gate_index, i))
|
||||||
.collect()
|
.collect()
|
||||||
}
|
}
|
||||||
@ -163,29 +175,29 @@ impl<F: Extendable<D>, const D: usize> SimpleGenerator<F> for ArithmeticExtensio
|
|||||||
witness.get_extension_target(t)
|
witness.get_extension_target(t)
|
||||||
};
|
};
|
||||||
|
|
||||||
let fixed_multiplicand =
|
|
||||||
extract_extension(ArithmeticExtensionGate::<D>::wires_fixed_multiplicand());
|
|
||||||
let multiplicand_0 =
|
let multiplicand_0 =
|
||||||
extract_extension(ArithmeticExtensionGate::<D>::wires_multiplicand_0());
|
extract_extension(ArithmeticExtensionGate::<D>::wires_first_multiplicand_0());
|
||||||
let addend_0 = extract_extension(ArithmeticExtensionGate::<D>::wires_addend_0());
|
let multiplicand_1 =
|
||||||
|
extract_extension(ArithmeticExtensionGate::<D>::wires_first_multiplicand_1());
|
||||||
|
let addend = extract_extension(ArithmeticExtensionGate::<D>::wires_first_addend());
|
||||||
|
|
||||||
let output_target_0 = ExtensionTarget::from_range(
|
let output_target = ExtensionTarget::from_range(
|
||||||
self.gate_index,
|
self.gate_index,
|
||||||
ArithmeticExtensionGate::<D>::wires_output_0(),
|
ArithmeticExtensionGate::<D>::wires_first_output(),
|
||||||
);
|
);
|
||||||
|
|
||||||
let computed_output_0 = fixed_multiplicand * multiplicand_0 * self.const_0.into()
|
let computed_output =
|
||||||
+ addend_0 * self.const_1.into();
|
multiplicand_0 * multiplicand_1 * self.const_0.into() + addend * self.const_1.into();
|
||||||
|
|
||||||
GeneratedValues::singleton_extension_target(output_target_0, computed_output_0)
|
GeneratedValues::singleton_extension_target(output_target, computed_output)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<F: Extendable<D>, const D: usize> SimpleGenerator<F> for ArithmeticExtensionGenerator1<F, D> {
|
impl<F: Extendable<D>, const D: usize> SimpleGenerator<F> for ArithmeticExtensionGenerator1<F, D> {
|
||||||
fn dependencies(&self) -> Vec<Target> {
|
fn dependencies(&self) -> Vec<Target> {
|
||||||
ArithmeticExtensionGate::<D>::wires_fixed_multiplicand()
|
ArithmeticExtensionGate::<D>::wires_second_multiplicand_0()
|
||||||
.chain(ArithmeticExtensionGate::<D>::wires_multiplicand_1())
|
.chain(ArithmeticExtensionGate::<D>::wires_second_multiplicand_1())
|
||||||
.chain(ArithmeticExtensionGate::<D>::wires_addend_1())
|
.chain(ArithmeticExtensionGate::<D>::wires_second_addend())
|
||||||
.map(|i| Target::wire(self.gate_index, i))
|
.map(|i| Target::wire(self.gate_index, i))
|
||||||
.collect()
|
.collect()
|
||||||
}
|
}
|
||||||
@ -196,21 +208,21 @@ impl<F: Extendable<D>, const D: usize> SimpleGenerator<F> for ArithmeticExtensio
|
|||||||
witness.get_extension_target(t)
|
witness.get_extension_target(t)
|
||||||
};
|
};
|
||||||
|
|
||||||
let fixed_multiplicand =
|
let multiplicand_0 =
|
||||||
extract_extension(ArithmeticExtensionGate::<D>::wires_fixed_multiplicand());
|
extract_extension(ArithmeticExtensionGate::<D>::wires_second_multiplicand_0());
|
||||||
let multiplicand_1 =
|
let multiplicand_1 =
|
||||||
extract_extension(ArithmeticExtensionGate::<D>::wires_multiplicand_1());
|
extract_extension(ArithmeticExtensionGate::<D>::wires_second_multiplicand_1());
|
||||||
let addend_1 = extract_extension(ArithmeticExtensionGate::<D>::wires_addend_1());
|
let addend = extract_extension(ArithmeticExtensionGate::<D>::wires_second_addend());
|
||||||
|
|
||||||
let output_target_1 = ExtensionTarget::from_range(
|
let output_target = ExtensionTarget::from_range(
|
||||||
self.gate_index,
|
self.gate_index,
|
||||||
ArithmeticExtensionGate::<D>::wires_output_1(),
|
ArithmeticExtensionGate::<D>::wires_second_output(),
|
||||||
);
|
);
|
||||||
|
|
||||||
let computed_output_1 = fixed_multiplicand * multiplicand_1 * self.const_0.into()
|
let computed_output =
|
||||||
+ addend_1 * self.const_1.into();
|
multiplicand_0 * multiplicand_1 * self.const_0.into() + addend * self.const_1.into();
|
||||||
|
|
||||||
GeneratedValues::singleton_extension_target(output_target_1, computed_output_1)
|
GeneratedValues::singleton_extension_target(output_target, computed_output)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -122,8 +122,10 @@ impl<const D: usize> ReducingFactorTarget<D> {
|
|||||||
// out_0 = alpha acc + pair[0]
|
// out_0 = alpha acc + pair[0]
|
||||||
// acc' = out_1 = alpha out_0 + pair[1]
|
// acc' = out_1 = alpha out_0 + pair[1]
|
||||||
let gate = builder.num_gates();
|
let gate = builder.num_gates();
|
||||||
let out_0 =
|
let out_0 = ExtensionTarget::from_range(
|
||||||
ExtensionTarget::from_range(gate, ArithmeticExtensionGate::<D>::wires_output_0());
|
gate,
|
||||||
|
ArithmeticExtensionGate::<D>::wires_first_output(),
|
||||||
|
);
|
||||||
acc = builder
|
acc = builder
|
||||||
.double_arithmetic_extension(
|
.double_arithmetic_extension(
|
||||||
F::ONE,
|
F::ONE,
|
||||||
@ -131,6 +133,7 @@ impl<const D: usize> ReducingFactorTarget<D> {
|
|||||||
self.base,
|
self.base,
|
||||||
acc,
|
acc,
|
||||||
pair[0],
|
pair[0],
|
||||||
|
self.base,
|
||||||
out_0,
|
out_0,
|
||||||
pair[1],
|
pair[1],
|
||||||
)
|
)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user