Simple reduce (#78)

* Simple reduce

* Fix bug causing test failure
This commit is contained in:
Daniel Lubarov 2021-06-29 12:33:11 -07:00 committed by GitHub
parent 9a352193ed
commit f1e3474fcb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 97 additions and 177 deletions

View File

@ -62,6 +62,10 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
}
}
pub fn num_gates(&self) -> usize {
self.gate_instances.len()
}
pub fn add_public_input(&mut self) -> Target {
let index = self.public_input_index;
self.public_input_index += 1;

View File

@ -2,6 +2,7 @@ use std::convert::TryInto;
use std::ops::Range;
use itertools::Itertools;
use num::Integer;
use crate::circuit_builder::CircuitBuilder;
use crate::field::extension_field::target::{ExtensionAlgebraTarget, ExtensionTarget};
@ -108,7 +109,7 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
let (o0, o1) = self.add_two_extension(a.0[i], b.0[i], a.0[j], b.0[j]);
res.extend([o0, o1]);
}
if D % 2 == 1 {
if D.is_odd() {
res.push(self.add_extension(a.0[D - 1], b.0[D - 1]));
}
ExtensionAlgebraTarget(res.try_into().unwrap())
@ -117,7 +118,7 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
pub fn add_many_extension(&mut self, terms: &[ExtensionTarget<D>]) -> ExtensionTarget<D> {
let zero = self.zero_extension();
let mut terms = terms.to_vec();
if terms.len() % 2 == 1 {
if terms.len().is_odd() {
terms.push(zero);
}
// We maintain two accumulators, one for the sum of even elements, and one for odd elements.
@ -164,7 +165,7 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
let (o0, o1) = self.sub_two_extension(a.0[i], b.0[i], a.0[j], b.0[j]);
res.extend([o0, o1]);
}
if D % 2 == 1 {
if D.is_odd() {
res.push(self.sub_extension(a.0[D - 1], b.0[D - 1]));
}
ExtensionAlgebraTarget(res.try_into().unwrap())

View File

@ -106,12 +106,17 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for ArithmeticExtensionGate<D>
gate_index: usize,
local_constants: &[F],
) -> Vec<Box<dyn WitnessGenerator<F>>> {
let gen = ArithmeticExtensionGenerator {
let gen0 = ArithmeticExtensionGenerator0 {
gate_index,
const_0: local_constants[0],
const_1: local_constants[1],
};
vec![Box::new(gen)]
let gen1 = ArithmeticExtensionGenerator1 {
gate_index,
const_0: local_constants[0],
const_1: local_constants[1],
};
vec![Box::new(gen0), Box::new(gen1)]
}
fn num_wires(&self) -> usize {
@ -131,19 +136,23 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for ArithmeticExtensionGate<D>
}
}
struct ArithmeticExtensionGenerator<F: Extendable<D>, const D: usize> {
struct ArithmeticExtensionGenerator0<F: Extendable<D>, const D: usize> {
gate_index: usize,
const_0: F,
const_1: F,
}
impl<F: Extendable<D>, const D: usize> SimpleGenerator<F> for ArithmeticExtensionGenerator<F, D> {
struct ArithmeticExtensionGenerator1<F: Extendable<D>, const D: usize> {
gate_index: usize,
const_0: F,
const_1: F,
}
impl<F: Extendable<D>, const D: usize> SimpleGenerator<F> for ArithmeticExtensionGenerator0<F, D> {
fn dependencies(&self) -> Vec<Target> {
ArithmeticExtensionGate::<D>::wires_fixed_multiplicand()
.chain(ArithmeticExtensionGate::<D>::wires_multiplicand_0())
.chain(ArithmeticExtensionGate::<D>::wires_addend_0())
.chain(ArithmeticExtensionGate::<D>::wires_multiplicand_1())
.chain(ArithmeticExtensionGate::<D>::wires_addend_1())
.map(|i| Target::wire(self.gate_index, i))
.collect()
}
@ -159,28 +168,49 @@ impl<F: Extendable<D>, const D: usize> SimpleGenerator<F> for ArithmeticExtensio
let multiplicand_0 =
extract_extension(ArithmeticExtensionGate::<D>::wires_multiplicand_0());
let addend_0 = extract_extension(ArithmeticExtensionGate::<D>::wires_addend_0());
let multiplicand_1 =
extract_extension(ArithmeticExtensionGate::<D>::wires_multiplicand_1());
let addend_1 = extract_extension(ArithmeticExtensionGate::<D>::wires_addend_1());
let output_target_0 = ExtensionTarget::from_range(
self.gate_index,
ArithmeticExtensionGate::<D>::wires_output_0(),
);
let computed_output_0 = fixed_multiplicand * multiplicand_0 * self.const_0.into()
+ addend_0 * self.const_1.into();
PartialWitness::singleton_extension_target(output_target_0, computed_output_0)
}
}
impl<F: Extendable<D>, const D: usize> SimpleGenerator<F> for ArithmeticExtensionGenerator1<F, D> {
fn dependencies(&self) -> Vec<Target> {
ArithmeticExtensionGate::<D>::wires_fixed_multiplicand()
.chain(ArithmeticExtensionGate::<D>::wires_multiplicand_1())
.chain(ArithmeticExtensionGate::<D>::wires_addend_1())
.map(|i| Target::wire(self.gate_index, i))
.collect()
}
fn run_once(&self, witness: &PartialWitness<F>) -> PartialWitness<F> {
let extract_extension = |range: Range<usize>| -> F::Extension {
let t = ExtensionTarget::from_range(self.gate_index, range);
witness.get_extension_target(t)
};
let fixed_multiplicand =
extract_extension(ArithmeticExtensionGate::<D>::wires_fixed_multiplicand());
let multiplicand_1 =
extract_extension(ArithmeticExtensionGate::<D>::wires_multiplicand_1());
let addend_1 = extract_extension(ArithmeticExtensionGate::<D>::wires_addend_1());
let output_target_1 = ExtensionTarget::from_range(
self.gate_index,
ArithmeticExtensionGate::<D>::wires_output_1(),
);
let computed_output_0 = fixed_multiplicand * multiplicand_0 * self.const_0.into()
+ addend_0 * self.const_1.into();
let computed_output_1 = fixed_multiplicand * multiplicand_1 * self.const_0.into()
+ addend_1 * self.const_1.into();
let mut pw = PartialWitness::new();
pw.set_extension_target(output_target_0, computed_output_0);
pw.set_extension_target(output_target_1, computed_output_1);
pw
PartialWitness::singleton_extension_target(output_target_1, computed_output_1)
}
}

View File

@ -1,14 +1,13 @@
use std::borrow::Borrow;
use num::Integer;
use crate::circuit_builder::CircuitBuilder;
use crate::field::extension_field::target::ExtensionTarget;
use crate::field::extension_field::{Extendable, Frobenius};
use crate::field::field::Field;
use crate::gates::arithmetic::ArithmeticExtensionGate;
use crate::generator::SimpleGenerator;
use crate::polynomial::polynomial::PolynomialCoeffs;
use crate::target::Target;
use crate::witness::PartialWitness;
/// When verifying the composition polynomial in FRI we have to compute sums of the form
/// `(sum_0^k a^i * x_i)/d_0 + (sum_k^r a^i * y_i)/d_1`
@ -98,113 +97,45 @@ impl<const D: usize> ReducingFactorTarget<D> {
/// which verifies that `2.reduce([1,2,3,4]) = 49`.
pub fn reduce<F>(
&mut self,
iter: &[ExtensionTarget<D>], // Could probably work with a `DoubleEndedIterator` too.
terms: &[ExtensionTarget<D>], // Could probably work with a `DoubleEndedIterator` too.
builder: &mut CircuitBuilder<F, D>,
) -> ExtensionTarget<D>
where
F: Extendable<D>,
{
let zero = builder.zero_extension();
let l = iter.len();
let l = terms.len();
self.count += l as u64;
// If needed we pad the original vector so that it has even length.
let padded_iter = if l % 2 == 0 {
iter.to_vec()
} else {
[iter, &[zero]].concat()
};
let half_length = padded_iter.len() / 2;
// Add `n/2` `ArithmeticExtensionGate`s that will perform the accumulation.
let gates = (0..half_length)
.map(|_| builder.add_gate(ArithmeticExtensionGate::new(), vec![F::ONE, F::ONE]))
.collect::<Vec<_>>();
// Add a generator that will fill the accumulation wires.
builder.add_generator(ParallelReductionGenerator {
base: self.base,
padded_iter: padded_iter.clone(),
gates: gates.clone(),
half_length,
});
let mut terms_vec = terms.to_vec();
// If needed, we pad the original vector so that it has even length.
if terms_vec.len().is_odd() {
terms_vec.push(zero);
}
terms_vec.reverse();
for i in 0..half_length {
// The fixed multiplicand is always `base`.
builder.route_extension(
self.base,
ExtensionTarget::from_range(
gates[i],
ArithmeticExtensionGate::<D>::wires_fixed_multiplicand(),
),
);
// Set the addends for the first half of the accumulation.
builder.route_extension(
padded_iter[2 * half_length - i - 1],
ExtensionTarget::from_range(
gates[i],
ArithmeticExtensionGate::<D>::wires_addend_0(),
),
);
// Set the addends for the second half of the accumulation.
builder.route_extension(
padded_iter[half_length - i - 1],
ExtensionTarget::from_range(
gates[i],
ArithmeticExtensionGate::<D>::wires_addend_1(),
),
);
let mut acc = zero;
for pair in terms_vec.chunks(2) {
// We will route the output of the first arithmetic operation to the multiplicand of the
// second, i.e. we compute the following:
// out_0 = alpha acc + pair[0]
// acc' = out_1 = alpha out_0 + pair[1]
let gate = builder.num_gates();
let out_0 =
ExtensionTarget::from_range(gate, ArithmeticExtensionGate::<D>::wires_output_0());
acc = builder
.double_arithmetic_extension(
F::ONE,
F::ONE,
self.base,
acc,
pair[0],
out_0,
pair[1],
)
.1;
}
for gate_pair in gates[..half_length].windows(2) {
// Verifies that the accumulator is passed between gates for the first half of the accumulation.
builder.assert_equal_extension(
ExtensionTarget::from_range(
gate_pair[0],
ArithmeticExtensionGate::<D>::wires_output_0(),
),
ExtensionTarget::from_range(
gate_pair[1],
ArithmeticExtensionGate::<D>::wires_multiplicand_0(),
),
);
}
for gate_pair in gates[half_length..].windows(2) {
// Verifies that the accumulator is passed between gates for the second half of the accumulation.
builder.assert_equal_extension(
ExtensionTarget::from_range(
gate_pair[0],
ArithmeticExtensionGate::<D>::wires_output_1(),
),
ExtensionTarget::from_range(
gate_pair[1],
ArithmeticExtensionGate::<D>::wires_multiplicand_1(),
),
);
}
// Verifies that the starting accumulator for the first half is zero.
builder.assert_equal_extension(
ExtensionTarget::from_range(
gates[0],
ArithmeticExtensionGate::<D>::wires_multiplicand_0(),
),
zero,
);
// Verifies that the final accumulator for the first half is passed as a starting
// accumulator for the second half.
builder.assert_equal_extension(
ExtensionTarget::from_range(
gates[half_length - 1],
ArithmeticExtensionGate::<D>::wires_output_0(),
),
ExtensionTarget::from_range(
gates[0],
ArithmeticExtensionGate::<D>::wires_multiplicand_1(),
),
);
// Return the final accumulator for the second half.
ExtensionTarget::from_range(
gates[half_length - 1],
ArithmeticExtensionGate::<D>::wires_output_1(),
)
acc
}
pub fn shift<F>(
@ -236,71 +167,13 @@ impl<const D: usize> ReducingFactorTarget<D> {
}
}
/// Fills the intermediate accumulator in `ReducingFactorTarget::reduce`.
struct ParallelReductionGenerator<const D: usize> {
base: ExtensionTarget<D>,
padded_iter: Vec<ExtensionTarget<D>>,
gates: Vec<usize>,
half_length: usize,
}
impl<F: Extendable<D>, const D: usize> SimpleGenerator<F> for ParallelReductionGenerator<D> {
fn dependencies(&self) -> Vec<Target> {
// Need only the values and the base.
self.padded_iter
.iter()
.flat_map(|ext| ext.to_target_array())
.chain(self.base.to_target_array())
.collect()
}
fn run_once(&self, witness: &PartialWitness<F>) -> PartialWitness<F> {
let mut pw = PartialWitness::new();
let base = witness.get_extension_target(self.base);
let vs = self
.padded_iter
.iter()
.map(|&ext| witness.get_extension_target(ext))
.collect::<Vec<_>>();
// Computed the intermediate accumulators.
let intermediate_accs = vs
.iter()
.rev()
.scan(F::Extension::ZERO, |acc, &x| {
let tmp = *acc;
*acc = *acc * base + x;
Some(tmp)
})
.collect::<Vec<_>>();
for i in 0..self.half_length {
// Fill the accumulators for the first half.
pw.set_extension_target(
ExtensionTarget::from_range(
self.gates[i],
ArithmeticExtensionGate::<D>::wires_multiplicand_0(),
),
intermediate_accs[i],
);
// Fill the accumulators for the second half.
pw.set_extension_target(
ExtensionTarget::from_range(
self.gates[i],
ArithmeticExtensionGate::<D>::wires_multiplicand_1(),
),
intermediate_accs[self.half_length + i],
);
}
pw
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::circuit_data::CircuitConfig;
use crate::field::crandall_field::CrandallField;
use crate::field::extension_field::quartic::QuarticCrandallField;
use crate::witness::PartialWitness;
fn test_reduce_gadget(n: usize) {
type F = CrandallField;

View File

@ -32,6 +32,18 @@ impl<F: Field> PartialWitness<F> {
witness
}
pub fn singleton_extension_target<const D: usize>(
et: ExtensionTarget<D>,
value: F::Extension,
) -> Self
where
F: Extendable<D>,
{
let mut witness = PartialWitness::new();
witness.set_extension_target(et, value);
witness
}
pub fn is_empty(&self) -> bool {
self.target_values.is_empty()
}