Use arithmetic gate for small reductions

This commit is contained in:
wborgeaud 2021-11-15 11:38:48 +01:00
parent 66719b0cfc
commit a54db66f68
2 changed files with 85 additions and 26 deletions

View File

@ -13,11 +13,11 @@ use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase};
/// Computes `sum alpha^i c_i` for a vector `c_i` of `num_coeffs` elements of the extension field.
#[derive(Debug, Clone)]
pub struct ReducingExtGate<const D: usize> {
pub struct ReducingExtensionGate<const D: usize> {
pub num_coeffs: usize,
}
impl<const D: usize> ReducingExtGate<D> {
impl<const D: usize> ReducingExtensionGate<D> {
pub fn new(num_coeffs: usize) -> Self {
Self { num_coeffs }
}
@ -51,7 +51,7 @@ impl<const D: usize> ReducingExtGate<D> {
}
}
impl<F: RichField + Extendable<D>, const D: usize> Gate<F, D> for ReducingExtGate<D> {
impl<F: RichField + Extendable<D>, const D: usize> Gate<F, D> for ReducingExtensionGate<D> {
fn id(&self) -> String {
format!("{:?}", self)
}
@ -163,14 +163,16 @@ impl<F: RichField + Extendable<D>, const D: usize> Gate<F, D> for ReducingExtGat
#[derive(Debug)]
struct ReducingGenerator<const D: usize> {
gate_index: usize,
gate: ReducingExtGate<D>,
gate: ReducingExtensionGate<D>,
}
impl<F: Extendable<D>, const D: usize> SimpleGenerator<F> for ReducingGenerator<D> {
fn dependencies(&self) -> Vec<Target> {
ReducingExtGate::<D>::wires_alpha()
.chain(ReducingExtGate::<D>::wires_old_acc())
.chain((0..self.gate.num_coeffs).flat_map(|i| ReducingExtGate::<D>::wires_coeff(i)))
ReducingExtensionGate::<D>::wires_alpha()
.chain(ReducingExtensionGate::<D>::wires_old_acc())
.chain(
(0..self.gate.num_coeffs).flat_map(|i| ReducingExtensionGate::<D>::wires_coeff(i)),
)
.map(|i| Target::wire(self.gate_index, i))
.collect()
}
@ -181,16 +183,18 @@ impl<F: Extendable<D>, const D: usize> SimpleGenerator<F> for ReducingGenerator<
witness.get_extension_target(t)
};
let alpha = extract_extension(ReducingExtGate::<D>::wires_alpha());
let old_acc = extract_extension(ReducingExtGate::<D>::wires_old_acc());
let alpha = extract_extension(ReducingExtensionGate::<D>::wires_alpha());
let old_acc = extract_extension(ReducingExtensionGate::<D>::wires_old_acc());
let coeffs = (0..self.gate.num_coeffs)
.map(|i| extract_extension(ReducingExtGate::<D>::wires_coeff(i)))
.map(|i| extract_extension(ReducingExtensionGate::<D>::wires_coeff(i)))
.collect::<Vec<_>>();
let accs = (0..self.gate.num_coeffs)
.map(|i| ExtensionTarget::from_range(self.gate_index, self.gate.wires_accs(i)))
.collect::<Vec<_>>();
let output =
ExtensionTarget::from_range(self.gate_index, ReducingExtGate::<D>::wires_output());
let output = ExtensionTarget::from_range(
self.gate_index,
ReducingExtensionGate::<D>::wires_output(),
);
let mut acc = old_acc;
for i in 0..self.gate.num_coeffs {
@ -208,15 +212,15 @@ mod tests {
use crate::field::goldilocks_field::GoldilocksField;
use crate::gates::gate_testing::{test_eval_fns, test_low_degree};
use crate::gates::reducing_extension::ReducingExtGate;
use crate::gates::reducing_extension::ReducingExtensionGate;
#[test]
fn low_degree() {
test_low_degree::<GoldilocksField, _, 4>(ReducingExtGate::new(22));
test_low_degree::<GoldilocksField, _, 4>(ReducingExtensionGate::new(22));
}
#[test]
fn eval_fns() -> Result<()> {
test_eval_fns::<GoldilocksField, _, 4>(ReducingExtGate::new(22))
test_eval_fns::<GoldilocksField, _, 4>(ReducingExtensionGate::new(22))
}
}

View File

@ -3,8 +3,9 @@ use std::borrow::Borrow;
use crate::field::extension_field::target::ExtensionTarget;
use crate::field::extension_field::Extendable;
use crate::field::field_types::{Field, RichField};
use crate::gates::arithmetic_extension::ArithmeticExtensionGate;
use crate::gates::reducing::ReducingGate;
use crate::gates::reducing_extension::ReducingExtGate;
use crate::gates::reducing_extension::ReducingExtensionGate;
use crate::iop::target::Target;
use crate::plonk::circuit_builder::CircuitBuilder;
use crate::polynomial::polynomial::PolynomialCoeffs;
@ -94,7 +95,7 @@ impl<const D: usize> ReducingFactorTarget<D> {
Self { base, count: 0 }
}
/// Reduces a length `n` vector of `Target`s using `n/21` `ReducingGate`s (with 33 routed wires and 126 wires).
/// Reduces a vector of `Target`s using `ReducingGate`s.
pub fn reduce_base<F>(
&mut self,
terms: &[Target],
@ -103,11 +104,16 @@ impl<const D: usize> ReducingFactorTarget<D> {
where
F: RichField + Extendable<D>,
{
let l = terms.len();
// For small reductions, use an arithmetic gate.
if l - 1 <= ArithmeticExtensionGate::<D>::new_from_config(&builder.config).num_ops {
return self.reduce_base_arithmetic(terms, builder);
}
let max_coeffs_len = ReducingGate::<D>::max_coeffs_len(
builder.config.num_wires,
builder.config.num_routed_wires,
);
self.count += terms.len() as u64;
self.count += l as u64;
let zero = builder.zero();
let zero_ext = builder.zero_extension();
let mut acc = zero_ext;
@ -138,6 +144,26 @@ impl<const D: usize> ReducingFactorTarget<D> {
acc
}
/// Reduces a vector of `Target`s using `ArithmeticGate`s.
fn reduce_base_arithmetic<F>(
&mut self,
terms: &[Target],
builder: &mut CircuitBuilder<F, D>,
) -> ExtensionTarget<D>
where
F: RichField + Extendable<D>,
{
self.count += terms.len() as u64;
terms
.iter()
.rev()
.fold(builder.zero_extension(), |acc, &t| {
let et = builder.convert_to_ext(t);
builder.mul_add_extension(self.base, acc, et)
})
}
/// Reduces a vector of `ExtensionTarget`s using `ReducingExtensionGate`s.
pub fn reduce<F>(
&mut self,
terms: &[ExtensionTarget<D>], // Could probably work with a `DoubleEndedIterator` too.
@ -146,12 +172,16 @@ impl<const D: usize> ReducingFactorTarget<D> {
where
F: RichField + Extendable<D>,
{
let max_coeffs_len = ReducingExtGate::<D>::max_coeffs_len(
let l = terms.len();
// For small reductions, use an arithmetic gate.
if l - 1 <= ArithmeticExtensionGate::<D>::new_from_config(&builder.config).num_ops {
return self.reduce_arithmetic(terms, builder);
}
let max_coeffs_len = ReducingExtensionGate::<D>::max_coeffs_len(
builder.config.num_wires,
builder.config.num_routed_wires,
);
self.count += terms.len() as u64;
let zero = builder.zero();
self.count += l as u64;
let zero_ext = builder.zero_extension();
let mut acc = zero_ext;
let mut reversed_terms = terms.to_vec();
@ -160,30 +190,55 @@ impl<const D: usize> ReducingFactorTarget<D> {
}
reversed_terms.reverse();
for chunk in reversed_terms.chunks_exact(max_coeffs_len) {
let gate = ReducingExtGate::new(max_coeffs_len);
let gate = ReducingExtensionGate::new(max_coeffs_len);
let gate_index = builder.add_gate(gate.clone(), Vec::new());
builder.connect_extension(
self.base,
ExtensionTarget::from_range(gate_index, ReducingExtGate::<D>::wires_alpha()),
ExtensionTarget::from_range(gate_index, ReducingExtensionGate::<D>::wires_alpha()),
);
builder.connect_extension(
acc,
ExtensionTarget::from_range(gate_index, ReducingExtGate::<D>::wires_old_acc()),
ExtensionTarget::from_range(
gate_index,
ReducingExtensionGate::<D>::wires_old_acc(),
),
);
for (i, &t) in chunk.iter().enumerate() {
builder.connect_extension(
t,
ExtensionTarget::from_range(gate_index, ReducingExtGate::<D>::wires_coeff(i)),
ExtensionTarget::from_range(
gate_index,
ReducingExtensionGate::<D>::wires_coeff(i),
),
);
}
acc = ExtensionTarget::from_range(gate_index, ReducingExtGate::<D>::wires_output());
acc =
ExtensionTarget::from_range(gate_index, ReducingExtensionGate::<D>::wires_output());
}
acc
}
/// Reduces a vector of `ExtensionTarget`s using `ArithmeticGate`s.
fn reduce_arithmetic<F>(
&mut self,
terms: &[ExtensionTarget<D>],
builder: &mut CircuitBuilder<F, D>,
) -> ExtensionTarget<D>
where
F: RichField + Extendable<D>,
{
self.count += terms.len() as u64;
terms
.iter()
.rev()
.fold(builder.zero_extension(), |acc, &et| {
builder.mul_add_extension(self.base, acc, et)
})
}
pub fn shift<F>(
&mut self,
x: ExtensionTarget<D>,