This commit is contained in:
wborgeaud 2021-06-25 17:24:22 +02:00
parent 2f06a78cb1
commit 636d8bef07
3 changed files with 50 additions and 29 deletions

View File

@ -97,7 +97,10 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
a: ExtensionAlgebraTarget<D>,
b: ExtensionAlgebraTarget<D>,
) -> ExtensionAlgebraTarget<D> {
// We run two additions in parallel. So `[a0,a1,a2,a3] + [b0,b1,b2,b3]` is computed with two
// `add_two_extension`, first `[a0,a1]+[b0,b1]` then `[a2,a3]+[b2,b3]`.
let mut res = Vec::with_capacity(D);
// We need some extra logic if D is odd.
let d_even = D & (D ^ 1); // = 2 * (D/2)
for mut chunk in &(0..d_even).chunks(2) {
let i = chunk.next().unwrap();
@ -117,11 +120,13 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
if terms.len() % 2 == 1 {
terms.push(zero);
}
// We maintain two accumulators, one for the sum of even elements, and one for odd elements.
let mut acc0 = zero;
let mut acc1 = zero;
for chunk in terms.chunks_exact(2) {
(acc0, acc1) = self.add_two_extension(acc0, chunk[0], acc1, chunk[1]);
}
// We sum both accumulators to get the final result.
self.add_extension(acc0, acc1)
}
@ -150,6 +155,7 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
a: ExtensionAlgebraTarget<D>,
b: ExtensionAlgebraTarget<D>,
) -> ExtensionAlgebraTarget<D> {
// See `add_ext_algebra`.
let mut res = Vec::with_capacity(D);
let d_even = D & (D ^ 1); // = 2 * (D/2)
for mut chunk in &(0..d_even).chunks(2) {
@ -319,6 +325,7 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
denominator: y,
quotient: multiplicand_0,
});
// We need to zero out the other wires for the `ArithmeticExtensionGenerator` to hit.
self.add_generator(ZeroOutGenerator {
gate_index: gate,
ranges: vec![

View File

@ -106,7 +106,7 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for ArithmeticExtensionGate<D>
gate_index: usize,
local_constants: &[F],
) -> Vec<Box<dyn WitnessGenerator<F>>> {
let gen = MulExtensionGenerator {
let gen = ArithmeticExtensionGenerator {
gate_index,
const_0: local_constants[0],
const_1: local_constants[1],
@ -131,13 +131,13 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for ArithmeticExtensionGate<D>
}
}
struct MulExtensionGenerator<F: Extendable<D>, const D: usize> {
struct ArithmeticExtensionGenerator<F: Extendable<D>, const D: usize> {
gate_index: usize,
const_0: F,
const_1: F,
}
impl<F: Extendable<D>, const D: usize> SimpleGenerator<F> for MulExtensionGenerator<F, D> {
impl<F: Extendable<D>, const D: usize> SimpleGenerator<F> for ArithmeticExtensionGenerator<F, D> {
fn dependencies(&self) -> Vec<Target> {
ArithmeticExtensionGate::<D>::wires_fixed_multiplicand()
.chain(ArithmeticExtensionGate::<D>::wires_multiplicand_0())

View File

@ -91,6 +91,11 @@ impl<const D: usize> ReducingFactorTarget<D> {
Self { base, count: 0 }
}
/// Reduces a length `n` vector of `ExtensionTarget`s using `n/2` `ArithmeticExtensionGate`s.
/// It does this by running two accumulators in parallel. Here's an example with `n=4, alpha=2, D=1`:
/// 1st gate: 2 0 4 11 2 4 24 <- 2*0+4= 4, 2*11+2=24
/// 2nd gate: 2 4 3 24 1 11 49 <- 2*4+3=11, 2*24+1=49
/// which verifies that `2.reduce([1,2,3,4]) = 49`.
pub fn reduce<F>(
&mut self,
iter: &[ExtensionTarget<D>], // Could probably work with a `DoubleEndedIterator` too.
@ -99,18 +104,22 @@ impl<const D: usize> ReducingFactorTarget<D> {
where
F: Extendable<D>,
{
let zero = builder.zero_extension();
let l = iter.len();
self.count += l as u64;
// If needed we pad the original vector so that it has even length.
let padded_iter = if l % 2 == 0 {
iter.to_vec()
} else {
[iter, &[builder.zero_extension()]].concat()
[iter, &[zero]].concat()
};
let half_length = padded_iter.len() / 2;
// Add `n/2` `ArithmeticExtensionGate`s that will perform the accumulation.
let gates = (0..half_length)
.map(|_| builder.add_gate(ArithmeticExtensionGate::new(), vec![F::ONE, F::ONE]))
.collect::<Vec<_>>();
// Add a generator that will fill the accumulation wires.
builder.add_generator(ParallelReductionGenerator {
base: self.base,
padded_iter: padded_iter.clone(),
@ -119,24 +128,33 @@ impl<const D: usize> ReducingFactorTarget<D> {
});
for i in 0..half_length {
// The fixed multiplicand is always `base`.
builder.route_extension(
self.base,
ExtensionTarget::from_range(
gates[i],
ArithmeticExtensionGate::<D>::wires_fixed_multiplicand(),
),
);
// Set the addends for the first half of the accumulation.
builder.route_extension(
padded_iter[2 * half_length - i - 1],
ExtensionTarget::from_range(
gates[i],
ArithmeticExtensionGate::<D>::wires_addend_0(),
),
padded_iter[2 * half_length - i - 1],
);
}
for i in 0..half_length {
// Set the addends for the second half of the accumulation.
builder.route_extension(
padded_iter[half_length - i - 1],
ExtensionTarget::from_range(
gates[i],
ArithmeticExtensionGate::<D>::wires_addend_1(),
),
padded_iter[half_length - i - 1],
);
}
for gate_pair in gates[..half_length].windows(2) {
// Verifies that the accumulator is passed between gates for the first half of the accumulation.
builder.assert_equal_extension(
ExtensionTarget::from_range(
gate_pair[0],
@ -149,6 +167,7 @@ impl<const D: usize> ReducingFactorTarget<D> {
);
}
for gate_pair in gates[half_length..].windows(2) {
// Verifies that the accumulator is passed between gates for the second half of the accumulation.
builder.assert_equal_extension(
ExtensionTarget::from_range(
gate_pair[0],
@ -160,6 +179,16 @@ impl<const D: usize> ReducingFactorTarget<D> {
),
);
}
// Verifies that the starting accumulator for the first half is zero.
builder.assert_equal_extension(
ExtensionTarget::from_range(
gates[0],
ArithmeticExtensionGate::<D>::wires_multiplicand_0(),
),
zero,
);
// Verifies that the final accumulator for the first half is passed as a starting
// accumulator for the second half.
builder.assert_equal_extension(
ExtensionTarget::from_range(
gates[half_length - 1],
@ -171,6 +200,7 @@ impl<const D: usize> ReducingFactorTarget<D> {
),
);
// Return the final accumulator for the second half.
ExtensionTarget::from_range(
gates[half_length - 1],
ArithmeticExtensionGate::<D>::wires_output_1(),
@ -206,6 +236,7 @@ impl<const D: usize> ReducingFactorTarget<D> {
}
}
/// Fills the intermediate accumulator in `ReducingFactorTarget::reduce`.
struct ParallelReductionGenerator<const D: usize> {
base: ExtensionTarget<D>,
padded_iter: Vec<ExtensionTarget<D>>,
@ -215,6 +246,7 @@ struct ParallelReductionGenerator<const D: usize> {
impl<F: Extendable<D>, const D: usize> SimpleGenerator<F> for ParallelReductionGenerator<D> {
fn dependencies(&self) -> Vec<Target> {
// Need only the values and the base.
self.padded_iter
.iter()
.flat_map(|ext| ext.to_target_array())
@ -230,6 +262,7 @@ impl<F: Extendable<D>, const D: usize> SimpleGenerator<F> for ParallelReductionG
.iter()
.map(|&ext| witness.get_extension_target(ext))
.collect::<Vec<_>>();
// Computed the intermediate accumulators.
let intermediate_accs = vs
.iter()
.rev()
@ -240,13 +273,7 @@ impl<F: Extendable<D>, const D: usize> SimpleGenerator<F> for ParallelReductionG
})
.collect::<Vec<_>>();
for i in 0..self.half_length {
pw.set_extension_target(
ExtensionTarget::from_range(
self.gates[i],
ArithmeticExtensionGate::<D>::wires_fixed_multiplicand(),
),
base,
);
// Fill the accumulators for the first half.
pw.set_extension_target(
ExtensionTarget::from_range(
self.gates[i],
@ -254,13 +281,7 @@ impl<F: Extendable<D>, const D: usize> SimpleGenerator<F> for ParallelReductionG
),
intermediate_accs[i],
);
pw.set_extension_target(
ExtensionTarget::from_range(
self.gates[i],
ArithmeticExtensionGate::<D>::wires_addend_0(),
),
vs[2 * self.half_length - i - 1],
);
// Fill the accumulators for the second half.
pw.set_extension_target(
ExtensionTarget::from_range(
self.gates[i],
@ -268,13 +289,6 @@ impl<F: Extendable<D>, const D: usize> SimpleGenerator<F> for ParallelReductionG
),
intermediate_accs[self.half_length + i],
);
pw.set_extension_target(
ExtensionTarget::from_range(
self.gates[i],
ArithmeticExtensionGate::<D>::wires_addend_1(),
),
vs[self.half_length - i - 1],
);
}
pw