Modify ArithmeticExtensionGate to support 32 wires

This commit is contained in:
wborgeaud 2021-07-21 17:20:08 +02:00
parent 8642a10fde
commit b59d497964
3 changed files with 158 additions and 100 deletions

View File

@ -17,37 +17,47 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
&mut self,
const_0: F,
const_1: F,
fixed_multiplicand: ExtensionTarget<D>,
multiplicand_0: ExtensionTarget<D>,
addend_0: ExtensionTarget<D>,
multiplicand_1: ExtensionTarget<D>,
addend_1: ExtensionTarget<D>,
first_multiplicand_0: ExtensionTarget<D>,
first_multiplicand_1: ExtensionTarget<D>,
first_addend: ExtensionTarget<D>,
second_multiplicand_0: ExtensionTarget<D>,
second_multiplicand_1: ExtensionTarget<D>,
second_addend: ExtensionTarget<D>,
) -> (ExtensionTarget<D>, ExtensionTarget<D>) {
let gate = self.add_gate(ArithmeticExtensionGate::new(), vec![const_0, const_1]);
let wire_fixed_multiplicand = ExtensionTarget::from_range(
let wire_first_multiplicand_0 = ExtensionTarget::from_range(
gate,
ArithmeticExtensionGate::<D>::wires_fixed_multiplicand(),
ArithmeticExtensionGate::<D>::wires_first_multiplicand_0(),
);
let wire_multiplicand_0 =
ExtensionTarget::from_range(gate, ArithmeticExtensionGate::<D>::wires_multiplicand_0());
let wire_addend_0 =
ExtensionTarget::from_range(gate, ArithmeticExtensionGate::<D>::wires_addend_0());
let wire_multiplicand_1 =
ExtensionTarget::from_range(gate, ArithmeticExtensionGate::<D>::wires_multiplicand_1());
let wire_addend_1 =
ExtensionTarget::from_range(gate, ArithmeticExtensionGate::<D>::wires_addend_1());
let wire_output_0 =
ExtensionTarget::from_range(gate, ArithmeticExtensionGate::<D>::wires_output_0());
let wire_output_1 =
ExtensionTarget::from_range(gate, ArithmeticExtensionGate::<D>::wires_output_1());
let wire_first_multiplicand_1 = ExtensionTarget::from_range(
gate,
ArithmeticExtensionGate::<D>::wires_first_multiplicand_1(),
);
let wire_first_addend =
ExtensionTarget::from_range(gate, ArithmeticExtensionGate::<D>::wires_first_addend());
let wire_second_multiplicand_0 = ExtensionTarget::from_range(
gate,
ArithmeticExtensionGate::<D>::wires_second_multiplicand_0(),
);
let wire_second_multiplicand_1 = ExtensionTarget::from_range(
gate,
ArithmeticExtensionGate::<D>::wires_second_multiplicand_1(),
);
let wire_second_addend =
ExtensionTarget::from_range(gate, ArithmeticExtensionGate::<D>::wires_second_addend());
let wire_first_output =
ExtensionTarget::from_range(gate, ArithmeticExtensionGate::<D>::wires_first_output());
let wire_second_output =
ExtensionTarget::from_range(gate, ArithmeticExtensionGate::<D>::wires_second_output());
self.route_extension(fixed_multiplicand, wire_fixed_multiplicand);
self.route_extension(multiplicand_0, wire_multiplicand_0);
self.route_extension(addend_0, wire_addend_0);
self.route_extension(multiplicand_1, wire_multiplicand_1);
self.route_extension(addend_1, wire_addend_1);
(wire_output_0, wire_output_1)
self.route_extension(first_multiplicand_0, wire_first_multiplicand_0);
self.route_extension(first_multiplicand_1, wire_first_multiplicand_1);
self.route_extension(first_addend, wire_first_addend);
self.route_extension(second_multiplicand_0, wire_second_multiplicand_0);
self.route_extension(second_multiplicand_1, wire_second_multiplicand_1);
self.route_extension(second_addend, wire_second_addend);
(wire_first_output, wire_second_output)
}
pub fn arithmetic_extension(
@ -67,6 +77,7 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
addend,
zero,
zero,
zero,
)
.0
}
@ -89,7 +100,7 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
b1: ExtensionTarget<D>,
) -> (ExtensionTarget<D>, ExtensionTarget<D>) {
let one = self.one_extension();
self.double_arithmetic_extension(F::ONE, F::ONE, one, a0, b0, a1, b1)
self.double_arithmetic_extension(F::ONE, F::ONE, one, a0, b0, one, a1, b1)
}
pub fn add_ext_algebra(
@ -147,7 +158,7 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
b1: ExtensionTarget<D>,
) -> (ExtensionTarget<D>, ExtensionTarget<D>) {
let one = self.one_extension();
self.double_arithmetic_extension(F::ONE, F::NEG_ONE, one, a0, b0, a1, b1)
self.double_arithmetic_extension(F::ONE, F::NEG_ONE, one, a0, b0, one, a1, b1)
}
pub fn sub_ext_algebra(
@ -185,6 +196,7 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
zero,
zero,
zero,
zero,
)
.0
}
@ -205,7 +217,8 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
a1: ExtensionTarget<D>,
b1: ExtensionTarget<D>,
) -> (ExtensionTarget<D>, ExtensionTarget<D>) {
todo!()
let zero = self.zero_extension();
self.double_arithmetic_extension(F::ONE, F::ZERO, a0, b0, zero, a1, b1, zero)
}
/// Computes `x^2`.
@ -239,19 +252,14 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
if terms.len().is_odd() {
terms.push(one);
}
// We maintain two accumulators, one for the sum of even elements, and one for odd elements.
// We maintain two accumulators, one for the product of even elements, and one for odd elements.
let mut acc0 = one;
let mut acc1 = one;
for chunk in terms.chunks_exact(2) {
(acc0, acc1) = self.mul_two_extension(acc0, chunk[0], acc1, chunk[1]);
}
// We sum both accumulators to get the final result.
self.add_extension(acc0, acc1)
let mut product = self.one_extension();
for term in terms {
product = self.mul_extension(product, *term);
}
product
// We multiply both accumulators to get the final result.
self.mul_extension(acc0, acc1)
}
/// Like `mul_add`, but for `ExtensionTarget`s.
@ -468,6 +476,41 @@ mod tests {
use crate::verifier::verify;
use crate::witness::PartialWitness;
#[test]
fn test_mul_many() -> Result<()> {
type F = CrandallField;
type FF = QuarticCrandallField;
const D: usize = 4;
let config = CircuitConfig::large_config();
let mut builder = CircuitBuilder::<F, D>::new(config);
let mut pw = PartialWitness::new();
let vs = FF::rand_vec(20);
let ts = builder.add_virtual_extension_targets(20);
for (&v, &t) in vs.iter().zip(&ts) {
pw.set_extension_target(t, v);
}
let mul0 = builder.mul_many_extension(&ts);
let mul1 = {
let mut acc = builder.one_extension();
for &t in &ts {
acc = builder.mul_extension(acc, t);
}
acc
};
let mul2 = builder.constant_extension(vs.into_iter().product());
builder.assert_equal_extension(mul0, mul1);
builder.assert_equal_extension(mul1, mul2);
let data = builder.build();
let proof = data.prove(pw)?;
verify(proof, &data.verifier_only, &data.common)
}
#[test]
fn test_div_extension() -> Result<()> {
type F = CrandallField;

View File

@ -18,27 +18,30 @@ impl<const D: usize> ArithmeticExtensionGate<D> {
GateRef::new(ArithmeticExtensionGate)
}
pub fn wires_fixed_multiplicand() -> Range<usize> {
pub fn wires_first_multiplicand_0() -> Range<usize> {
0..D
}
pub fn wires_multiplicand_0() -> Range<usize> {
pub fn wires_first_multiplicand_1() -> Range<usize> {
D..2 * D
}
pub fn wires_addend_0() -> Range<usize> {
pub fn wires_first_addend() -> Range<usize> {
2 * D..3 * D
}
pub fn wires_multiplicand_1() -> Range<usize> {
pub fn wires_second_multiplicand_0() -> Range<usize> {
3 * D..4 * D
}
pub fn wires_addend_1() -> Range<usize> {
pub fn wires_second_multiplicand_1() -> Range<usize> {
4 * D..5 * D
}
pub fn wires_output_0() -> Range<usize> {
pub fn wires_second_addend() -> Range<usize> {
5 * D..6 * D
}
pub fn wires_output_1() -> Range<usize> {
pub fn wires_first_output() -> Range<usize> {
6 * D..7 * D
}
pub fn wires_second_output() -> Range<usize> {
7 * D..8 * D
}
}
impl<F: Extendable<D>, const D: usize> Gate<F, D> for ArithmeticExtensionGate<D> {
@ -50,21 +53,24 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for ArithmeticExtensionGate<D>
let const_0 = vars.local_constants[0];
let const_1 = vars.local_constants[1];
let fixed_multiplicand = vars.get_local_ext_algebra(Self::wires_fixed_multiplicand());
let multiplicand_0 = vars.get_local_ext_algebra(Self::wires_multiplicand_0());
let addend_0 = vars.get_local_ext_algebra(Self::wires_addend_0());
let multiplicand_1 = vars.get_local_ext_algebra(Self::wires_multiplicand_1());
let addend_1 = vars.get_local_ext_algebra(Self::wires_addend_1());
let output_0 = vars.get_local_ext_algebra(Self::wires_output_0());
let output_1 = vars.get_local_ext_algebra(Self::wires_output_1());
let first_multiplicand_0 = vars.get_local_ext_algebra(Self::wires_first_multiplicand_0());
let first_multiplicand_1 = vars.get_local_ext_algebra(Self::wires_first_multiplicand_1());
let first_addend = vars.get_local_ext_algebra(Self::wires_first_addend());
let second_multiplicand_0 = vars.get_local_ext_algebra(Self::wires_second_multiplicand_0());
let second_multiplicand_1 = vars.get_local_ext_algebra(Self::wires_second_multiplicand_1());
let second_addend = vars.get_local_ext_algebra(Self::wires_second_addend());
let first_output = vars.get_local_ext_algebra(Self::wires_first_output());
let second_output = vars.get_local_ext_algebra(Self::wires_second_output());
let computed_output_0 =
fixed_multiplicand * multiplicand_0 * const_0.into() + addend_0 * const_1.into();
let computed_output_1 =
fixed_multiplicand * multiplicand_1 * const_0.into() + addend_1 * const_1.into();
let first_computed_output = first_multiplicand_0 * first_multiplicand_1 * const_0.into()
+ first_addend * const_1.into();
let second_computed_output = second_multiplicand_0 * second_multiplicand_1 * const_0.into()
+ second_addend * const_1.into();
let mut constraints = (output_0 - computed_output_0).to_basefield_array().to_vec();
constraints.extend((output_1 - computed_output_1).to_basefield_array());
let mut constraints = (first_output - first_computed_output)
.to_basefield_array()
.to_vec();
constraints.extend((second_output - second_computed_output).to_basefield_array());
constraints
}
@ -76,26 +82,32 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for ArithmeticExtensionGate<D>
let const_0 = vars.local_constants[0];
let const_1 = vars.local_constants[1];
let fixed_multiplicand = vars.get_local_ext_algebra(Self::wires_fixed_multiplicand());
let multiplicand_0 = vars.get_local_ext_algebra(Self::wires_multiplicand_0());
let addend_0 = vars.get_local_ext_algebra(Self::wires_addend_0());
let multiplicand_1 = vars.get_local_ext_algebra(Self::wires_multiplicand_1());
let addend_1 = vars.get_local_ext_algebra(Self::wires_addend_1());
let output_0 = vars.get_local_ext_algebra(Self::wires_output_0());
let output_1 = vars.get_local_ext_algebra(Self::wires_output_1());
let first_multiplicand_0 = vars.get_local_ext_algebra(Self::wires_first_multiplicand_0());
let first_multiplicand_1 = vars.get_local_ext_algebra(Self::wires_first_multiplicand_1());
let first_addend = vars.get_local_ext_algebra(Self::wires_first_addend());
let second_multiplicand_0 = vars.get_local_ext_algebra(Self::wires_second_multiplicand_0());
let second_multiplicand_1 = vars.get_local_ext_algebra(Self::wires_second_multiplicand_1());
let second_addend = vars.get_local_ext_algebra(Self::wires_second_addend());
let first_output = vars.get_local_ext_algebra(Self::wires_first_output());
let second_output = vars.get_local_ext_algebra(Self::wires_second_output());
let computed_output_0 = builder.mul_ext_algebra(fixed_multiplicand, multiplicand_0);
let computed_output_0 = builder.scalar_mul_ext_algebra(const_0, computed_output_0);
let scaled_addend_0 = builder.scalar_mul_ext_algebra(const_1, addend_0);
let computed_output_0 = builder.add_ext_algebra(computed_output_0, scaled_addend_0);
let first_computed_output =
builder.mul_ext_algebra(first_multiplicand_0, first_multiplicand_1);
let first_computed_output = builder.scalar_mul_ext_algebra(const_0, first_computed_output);
let first_scaled_addend = builder.scalar_mul_ext_algebra(const_1, first_addend);
let first_computed_output =
builder.add_ext_algebra(first_computed_output, first_scaled_addend);
let computed_output_1 = builder.mul_ext_algebra(fixed_multiplicand, multiplicand_1);
let computed_output_1 = builder.scalar_mul_ext_algebra(const_0, computed_output_1);
let scaled_addend_1 = builder.scalar_mul_ext_algebra(const_1, addend_1);
let computed_output_1 = builder.add_ext_algebra(computed_output_1, scaled_addend_1);
let second_computed_output =
builder.mul_ext_algebra(second_multiplicand_0, second_multiplicand_1);
let second_computed_output =
builder.scalar_mul_ext_algebra(const_0, second_computed_output);
let second_scaled_addend = builder.scalar_mul_ext_algebra(const_1, second_addend);
let second_computed_output =
builder.add_ext_algebra(second_computed_output, second_scaled_addend);
let diff_0 = builder.sub_ext_algebra(output_0, computed_output_0);
let diff_1 = builder.sub_ext_algebra(output_1, computed_output_1);
let diff_0 = builder.sub_ext_algebra(first_output, first_computed_output);
let diff_1 = builder.sub_ext_algebra(second_output, second_computed_output);
let mut constraints = diff_0.to_ext_target_array().to_vec();
constraints.extend(diff_1.to_ext_target_array());
constraints
@ -120,7 +132,7 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for ArithmeticExtensionGate<D>
}
fn num_wires(&self) -> usize {
7 * D
8 * D
}
fn num_constants(&self) -> usize {
@ -150,9 +162,9 @@ struct ArithmeticExtensionGenerator1<F: Extendable<D>, const D: usize> {
impl<F: Extendable<D>, const D: usize> SimpleGenerator<F> for ArithmeticExtensionGenerator0<F, D> {
fn dependencies(&self) -> Vec<Target> {
ArithmeticExtensionGate::<D>::wires_fixed_multiplicand()
.chain(ArithmeticExtensionGate::<D>::wires_multiplicand_0())
.chain(ArithmeticExtensionGate::<D>::wires_addend_0())
ArithmeticExtensionGate::<D>::wires_first_multiplicand_0()
.chain(ArithmeticExtensionGate::<D>::wires_first_multiplicand_1())
.chain(ArithmeticExtensionGate::<D>::wires_first_addend())
.map(|i| Target::wire(self.gate_index, i))
.collect()
}
@ -163,29 +175,29 @@ impl<F: Extendable<D>, const D: usize> SimpleGenerator<F> for ArithmeticExtensio
witness.get_extension_target(t)
};
let fixed_multiplicand =
extract_extension(ArithmeticExtensionGate::<D>::wires_fixed_multiplicand());
let multiplicand_0 =
extract_extension(ArithmeticExtensionGate::<D>::wires_multiplicand_0());
let addend_0 = extract_extension(ArithmeticExtensionGate::<D>::wires_addend_0());
extract_extension(ArithmeticExtensionGate::<D>::wires_first_multiplicand_0());
let multiplicand_1 =
extract_extension(ArithmeticExtensionGate::<D>::wires_first_multiplicand_1());
let addend = extract_extension(ArithmeticExtensionGate::<D>::wires_first_addend());
let output_target_0 = ExtensionTarget::from_range(
let output_target = ExtensionTarget::from_range(
self.gate_index,
ArithmeticExtensionGate::<D>::wires_output_0(),
ArithmeticExtensionGate::<D>::wires_first_output(),
);
let computed_output_0 = fixed_multiplicand * multiplicand_0 * self.const_0.into()
+ addend_0 * self.const_1.into();
let computed_output =
multiplicand_0 * multiplicand_1 * self.const_0.into() + addend * self.const_1.into();
PartialWitness::singleton_extension_target(output_target_0, computed_output_0)
PartialWitness::singleton_extension_target(output_target, computed_output)
}
}
impl<F: Extendable<D>, const D: usize> SimpleGenerator<F> for ArithmeticExtensionGenerator1<F, D> {
fn dependencies(&self) -> Vec<Target> {
ArithmeticExtensionGate::<D>::wires_fixed_multiplicand()
.chain(ArithmeticExtensionGate::<D>::wires_multiplicand_1())
.chain(ArithmeticExtensionGate::<D>::wires_addend_1())
ArithmeticExtensionGate::<D>::wires_second_multiplicand_0()
.chain(ArithmeticExtensionGate::<D>::wires_second_multiplicand_1())
.chain(ArithmeticExtensionGate::<D>::wires_second_addend())
.map(|i| Target::wire(self.gate_index, i))
.collect()
}
@ -196,21 +208,21 @@ impl<F: Extendable<D>, const D: usize> SimpleGenerator<F> for ArithmeticExtensio
witness.get_extension_target(t)
};
let fixed_multiplicand =
extract_extension(ArithmeticExtensionGate::<D>::wires_fixed_multiplicand());
let multiplicand_0 =
extract_extension(ArithmeticExtensionGate::<D>::wires_second_multiplicand_0());
let multiplicand_1 =
extract_extension(ArithmeticExtensionGate::<D>::wires_multiplicand_1());
let addend_1 = extract_extension(ArithmeticExtensionGate::<D>::wires_addend_1());
extract_extension(ArithmeticExtensionGate::<D>::wires_second_multiplicand_1());
let addend = extract_extension(ArithmeticExtensionGate::<D>::wires_second_addend());
let output_target_1 = ExtensionTarget::from_range(
let output_target = ExtensionTarget::from_range(
self.gate_index,
ArithmeticExtensionGate::<D>::wires_output_1(),
ArithmeticExtensionGate::<D>::wires_second_output(),
);
let computed_output_1 = fixed_multiplicand * multiplicand_1 * self.const_0.into()
+ addend_1 * self.const_1.into();
let computed_output =
multiplicand_0 * multiplicand_1 * self.const_0.into() + addend * self.const_1.into();
PartialWitness::singleton_extension_target(output_target_1, computed_output_1)
PartialWitness::singleton_extension_target(output_target, computed_output)
}
}

View File

@ -122,8 +122,10 @@ impl<const D: usize> ReducingFactorTarget<D> {
// out_0 = alpha acc + pair[0]
// acc' = out_1 = alpha out_0 + pair[1]
let gate = builder.num_gates();
let out_0 =
ExtensionTarget::from_range(gate, ArithmeticExtensionGate::<D>::wires_output_0());
let out_0 = ExtensionTarget::from_range(
gate,
ArithmeticExtensionGate::<D>::wires_first_output(),
);
acc = builder
.double_arithmetic_extension(
F::ONE,
@ -131,6 +133,7 @@ impl<const D: usize> ReducingFactorTarget<D> {
self.base,
acc,
pair[0],
self.base,
out_0,
pair[1],
)