Specialize InterpolationGate (#339)

* Specialize `InterpolationGate`

To cosets of subgroups of roots of unity. This way
- `InterpolationGate` needs fewer routed wires, bringing our minimum routed wires down from 28 to 25.
- The recursive `compute_evaluation` avoids some multiplications, saving 100~200 gates depending on `num_routed_wires`.

* Update test

* feedback
This commit is contained in:
Daniel Lubarov 2021-11-05 09:29:08 -07:00 committed by GitHub
parent 75fe5686a2
commit 671bb9be2e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 117 additions and 95 deletions

View File

@ -43,16 +43,7 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
let coset_start = self.mul(start, x); let coset_start = self.mul(start, x);
// The answer is gotten by interpolating {(x*g^i, P(x*g^i))} and evaluating at beta. // The answer is gotten by interpolating {(x*g^i, P(x*g^i))} and evaluating at beta.
let points = g self.interpolate_coset(arity_bits, coset_start, &evals, beta)
.powers()
.map(|y| {
let yc = self.constant(y);
self.mul(coset_start, yc)
})
.zip(evals)
.collect::<Vec<_>>();
self.interpolate(&points, beta)
} }
/// Make sure we have enough wires and routed wires to do the FRI checks efficiently. This check /// Make sure we have enough wires and routed wires to do the FRI checks efficiently. This check
@ -63,7 +54,7 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
&self.config, &self.config,
max_fri_arity.max(1 << self.config.cap_height), max_fri_arity.max(1 << self.config.cap_height),
); );
let interpolation_gate = InterpolationGate::<F, D>::new(max_fri_arity); let interpolation_gate = InterpolationGate::<F, D>::new(log2_strict(max_fri_arity));
let min_wires = random_access let min_wires = random_access
.num_wires() .num_wires()

View File

@ -6,17 +6,20 @@ use crate::iop::target::Target;
use crate::plonk::circuit_builder::CircuitBuilder; use crate::plonk::circuit_builder::CircuitBuilder;
impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> { impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
/// Interpolate a list of point/evaluation pairs at a given point. /// Interpolates a polynomial, whose points are a coset of the multiplicative subgroup with the
/// Returns the evaluation of the interpolated polynomial at `evaluation_point`. /// given size, and whose values are given. Returns the evaluation of the interpolant at
pub fn interpolate( /// `evaluation_point`.
pub fn interpolate_coset(
&mut self, &mut self,
interpolation_points: &[(Target, ExtensionTarget<D>)], subgroup_bits: usize,
coset_shift: Target,
values: &[ExtensionTarget<D>],
evaluation_point: ExtensionTarget<D>, evaluation_point: ExtensionTarget<D>,
) -> ExtensionTarget<D> { ) -> ExtensionTarget<D> {
let gate = InterpolationGate::new(interpolation_points.len()); let gate = InterpolationGate::new(subgroup_bits);
let gate_index = self.add_gate(gate.clone(), vec![]); let gate_index = self.add_gate(gate.clone(), vec![]);
for (i, &(p, v)) in interpolation_points.iter().enumerate() { self.connect(coset_shift, Target::wire(gate_index, gate.wire_shift()));
self.connect(p, Target::wire(gate_index, gate.wire_point(i))); for (i, &v) in values.iter().enumerate() {
self.connect_extension( self.connect_extension(
v, v,
ExtensionTarget::from_range(gate_index, gate.wires_value(i)), ExtensionTarget::from_range(gate_index, gate.wires_value(i)),
@ -53,14 +56,17 @@ mod tests {
let pw = PartialWitness::new(); let pw = PartialWitness::new();
let mut builder = CircuitBuilder::<F, 4>::new(config); let mut builder = CircuitBuilder::<F, 4>::new(config);
let len = 4; let subgroup_bits = 2;
let points = (0..len) let len = 1 << subgroup_bits;
.map(|_| (F::rand(), FF::rand())) let coset_shift = F::rand();
.collect::<Vec<_>>(); let g = F::primitive_root_of_unity(subgroup_bits);
let points = F::cyclic_subgroup_coset_known_order(g, coset_shift, len);
let values = FF::rand_vec(len);
let homogeneous_points = points let homogeneous_points = points
.iter() .iter()
.map(|&(a, b)| (<FF as FieldExtension<4>>::from_basefield(a), b)) .zip(values.iter())
.map(|(&a, &b)| (<FF as FieldExtension<4>>::from_basefield(a), b))
.collect::<Vec<_>>(); .collect::<Vec<_>>();
let true_interpolant = interpolant(&homogeneous_points); let true_interpolant = interpolant(&homogeneous_points);
@ -68,14 +74,16 @@ mod tests {
let z = FF::rand(); let z = FF::rand();
let true_eval = true_interpolant.eval(z); let true_eval = true_interpolant.eval(z);
let points_target = points let coset_shift_target = builder.constant(coset_shift);
let value_targets = values
.iter() .iter()
.map(|&(p, v)| (builder.constant(p), builder.constant_extension(v))) .map(|&v| (builder.constant_extension(v)))
.collect::<Vec<_>>(); .collect::<Vec<_>>();
let zt = builder.constant_extension(z); let zt = builder.constant_extension(z);
let eval = builder.interpolate(&points_target, zt); let eval = builder.interpolate_coset(subgroup_bits, coset_shift_target, &value_targets, zt);
let true_eval_target = builder.constant_extension(true_eval); let true_eval_target = builder.constant_extension(true_eval);
builder.connect_extension(eval, true_eval_target); builder.connect_extension(eval, true_eval_target);

View File

@ -242,7 +242,7 @@ mod tests {
GateRef::new(ArithmeticExtensionGate { num_ops: 4 }), GateRef::new(ArithmeticExtensionGate { num_ops: 4 }),
GateRef::new(BaseSumGate::<4>::new(4)), GateRef::new(BaseSumGate::<4>::new(4)),
GateRef::new(GMiMCGate::<F, D, 12>::new()), GateRef::new(GMiMCGate::<F, D, 12>::new()),
GateRef::new(InterpolationGate::new(4)), GateRef::new(InterpolationGate::new(2)),
]; ];
let (tree, _, _) = Tree::from_gates(gates.clone()); let (tree, _, _) = Tree::from_gates(gates.clone());

View File

@ -17,48 +17,45 @@ use crate::plonk::circuit_builder::CircuitBuilder;
use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase}; use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase};
use crate::polynomial::polynomial::PolynomialCoeffs; use crate::polynomial::polynomial::PolynomialCoeffs;
/// Evaluates the interpolant of some given elements from a field extension. /// Interpolates a polynomial, whose points are a (base field) coset of the multiplicative subgroup
/// /// with the given size, and whose values are extension field elements, given by input wires.
/// In particular, this gate takes as inputs `num_points` points, `num_points` values, and the point /// Outputs the evaluation of the interpolant at a given (extension field) evaluation point.
/// to evaluate the interpolant at. It computes the interpolant and outputs its evaluation at the
/// given point.
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub(crate) struct InterpolationGate<F: RichField + Extendable<D>, const D: usize> { pub(crate) struct InterpolationGate<F: RichField + Extendable<D>, const D: usize> {
pub num_points: usize, pub subgroup_bits: usize,
_phantom: PhantomData<F>, _phantom: PhantomData<F>,
} }
impl<F: RichField + Extendable<D>, const D: usize> InterpolationGate<F, D> { impl<F: RichField + Extendable<D>, const D: usize> InterpolationGate<F, D> {
pub fn new(num_points: usize) -> Self { pub fn new(subgroup_bits: usize) -> Self {
Self { Self {
num_points, subgroup_bits,
_phantom: PhantomData, _phantom: PhantomData,
} }
} }
fn start_points(&self) -> usize { fn num_points(&self) -> usize {
1 << self.subgroup_bits
}
/// Wire index of the coset shift.
pub fn wire_shift(&self) -> usize {
0 0
} }
/// Wire indices of the `i`th interpolant point.
pub fn wire_point(&self, i: usize) -> usize {
debug_assert!(i < self.num_points);
self.start_points() + i
}
fn start_values(&self) -> usize { fn start_values(&self) -> usize {
self.start_points() + self.num_points 1
} }
/// Wire indices of the `i`th interpolant value. /// Wire indices of the `i`th interpolant value.
pub fn wires_value(&self, i: usize) -> Range<usize> { pub fn wires_value(&self, i: usize) -> Range<usize> {
debug_assert!(i < self.num_points); debug_assert!(i < self.num_points());
let start = self.start_values() + i * D; let start = self.start_values() + i * D;
start..start + D start..start + D
} }
fn start_evaluation_point(&self) -> usize { fn start_evaluation_point(&self) -> usize {
self.start_values() + self.num_points * D self.start_values() + self.num_points() * D
} }
/// Wire indices of the point to evaluate the interpolant at. /// Wire indices of the point to evaluate the interpolant at.
@ -89,14 +86,46 @@ impl<F: RichField + Extendable<D>, const D: usize> InterpolationGate<F, D> {
/// Wire indices of the interpolant's `i`th coefficient. /// Wire indices of the interpolant's `i`th coefficient.
pub fn wires_coeff(&self, i: usize) -> Range<usize> { pub fn wires_coeff(&self, i: usize) -> Range<usize> {
debug_assert!(i < self.num_points); debug_assert!(i < self.num_points());
let start = self.start_coeffs() + i * D; let start = self.start_coeffs() + i * D;
start..start + D start..start + D
} }
/// End of wire indices, exclusive. /// End of wire indices, exclusive.
fn end(&self) -> usize { fn end(&self) -> usize {
self.start_coeffs() + self.num_points * D self.start_coeffs() + self.num_points() * D
}
/// The domain of the points we're interpolating.
fn coset(&self, shift: F) -> impl Iterator<Item = F> {
let g = F::primitive_root_of_unity(self.subgroup_bits);
let size = 1 << self.subgroup_bits;
// Speed matters here, so we avoid `cyclic_subgroup_coset_known_order` which allocates.
g.powers().take(size).map(move |x| x * shift)
}
/// The domain of the points we're interpolating.
fn coset_ext(&self, shift: F::Extension) -> impl Iterator<Item = F::Extension> {
let g = F::primitive_root_of_unity(self.subgroup_bits);
let size = 1 << self.subgroup_bits;
g.powers().take(size).map(move |x| shift.scalar_mul(x))
}
/// The domain of the points we're interpolating.
fn coset_ext_recursive(
&self,
builder: &mut CircuitBuilder<F, D>,
shift: ExtensionTarget<D>,
) -> Vec<ExtensionTarget<D>> {
let g = F::primitive_root_of_unity(self.subgroup_bits);
let size = 1 << self.subgroup_bits;
g.powers()
.take(size)
.map(move |x| {
let subgroup_element = builder.constant(x.into());
builder.scalar_mul_ext(subgroup_element, shift)
})
.collect()
} }
} }
@ -108,13 +137,13 @@ impl<F: RichField + Extendable<D>, const D: usize> Gate<F, D> for InterpolationG
fn eval_unfiltered(&self, vars: EvaluationVars<F, D>) -> Vec<F::Extension> { fn eval_unfiltered(&self, vars: EvaluationVars<F, D>) -> Vec<F::Extension> {
let mut constraints = Vec::with_capacity(self.num_constraints()); let mut constraints = Vec::with_capacity(self.num_constraints());
let coeffs = (0..self.num_points) let coeffs = (0..self.num_points())
.map(|i| vars.get_local_ext_algebra(self.wires_coeff(i))) .map(|i| vars.get_local_ext_algebra(self.wires_coeff(i)))
.collect(); .collect();
let interpolant = PolynomialCoeffsAlgebra::new(coeffs); let interpolant = PolynomialCoeffsAlgebra::new(coeffs);
for i in 0..self.num_points { let coset = self.coset_ext(vars.local_wires[self.wire_shift()]);
let point = vars.local_wires[self.wire_point(i)]; for (i, point) in coset.into_iter().enumerate() {
let value = vars.get_local_ext_algebra(self.wires_value(i)); let value = vars.get_local_ext_algebra(self.wires_value(i));
let computed_value = interpolant.eval_base(point); let computed_value = interpolant.eval_base(point);
constraints.extend(&(value - computed_value).to_basefield_array()); constraints.extend(&(value - computed_value).to_basefield_array());
@ -131,13 +160,13 @@ impl<F: RichField + Extendable<D>, const D: usize> Gate<F, D> for InterpolationG
fn eval_unfiltered_base(&self, vars: EvaluationVarsBase<F>) -> Vec<F> { fn eval_unfiltered_base(&self, vars: EvaluationVarsBase<F>) -> Vec<F> {
let mut constraints = Vec::with_capacity(self.num_constraints()); let mut constraints = Vec::with_capacity(self.num_constraints());
let coeffs = (0..self.num_points) let coeffs = (0..self.num_points())
.map(|i| vars.get_local_ext(self.wires_coeff(i))) .map(|i| vars.get_local_ext(self.wires_coeff(i)))
.collect(); .collect();
let interpolant = PolynomialCoeffs::new(coeffs); let interpolant = PolynomialCoeffs::new(coeffs);
for i in 0..self.num_points { let coset = self.coset(vars.local_wires[self.wire_shift()]);
let point = vars.local_wires[self.wire_point(i)]; for (i, point) in coset.into_iter().enumerate() {
let value = vars.get_local_ext(self.wires_value(i)); let value = vars.get_local_ext(self.wires_value(i));
let computed_value = interpolant.eval_base(point); let computed_value = interpolant.eval_base(point);
constraints.extend(&(value - computed_value).to_basefield_array()); constraints.extend(&(value - computed_value).to_basefield_array());
@ -158,13 +187,13 @@ impl<F: RichField + Extendable<D>, const D: usize> Gate<F, D> for InterpolationG
) -> Vec<ExtensionTarget<D>> { ) -> Vec<ExtensionTarget<D>> {
let mut constraints = Vec::with_capacity(self.num_constraints()); let mut constraints = Vec::with_capacity(self.num_constraints());
let coeffs = (0..self.num_points) let coeffs = (0..self.num_points())
.map(|i| vars.get_local_ext_algebra(self.wires_coeff(i))) .map(|i| vars.get_local_ext_algebra(self.wires_coeff(i)))
.collect(); .collect();
let interpolant = PolynomialCoeffsExtAlgebraTarget(coeffs); let interpolant = PolynomialCoeffsExtAlgebraTarget(coeffs);
for i in 0..self.num_points { let coset = self.coset_ext_recursive(builder, vars.local_wires[self.wire_shift()]);
let point = vars.local_wires[self.wire_point(i)]; for (i, point) in coset.into_iter().enumerate() {
let value = vars.get_local_ext_algebra(self.wires_value(i)); let value = vars.get_local_ext_algebra(self.wires_value(i));
let computed_value = interpolant.eval_scalar(builder, point); let computed_value = interpolant.eval_scalar(builder, point);
constraints.extend( constraints.extend(
@ -210,13 +239,13 @@ impl<F: RichField + Extendable<D>, const D: usize> Gate<F, D> for InterpolationG
fn degree(&self) -> usize { fn degree(&self) -> usize {
// The highest power of x is `num_points - 1`, and then multiplication by the coefficient // The highest power of x is `num_points - 1`, and then multiplication by the coefficient
// adds 1. // adds 1.
self.num_points self.num_points()
} }
fn num_constraints(&self) -> usize { fn num_constraints(&self) -> usize {
// num_points * D constraints to check for consistency between the coefficients and the // num_points * D constraints to check for consistency between the coefficients and the
// point-value pairs, plus D constraints for the evaluation value. // point-value pairs, plus D constraints for the evaluation value.
self.num_points * D + D self.num_points() * D + D
} }
} }
@ -240,18 +269,18 @@ impl<F: RichField + Extendable<D>, const D: usize> SimpleGenerator<F>
let local_targets = |inputs: Range<usize>| inputs.map(local_target); let local_targets = |inputs: Range<usize>| inputs.map(local_target);
let mut deps = Vec::new(); let num_points = self.gate.num_points();
let mut deps = Vec::with_capacity(1 + D + num_points * D);
deps.push(local_target(self.gate.wire_shift()));
deps.extend(local_targets(self.gate.wires_evaluation_point())); deps.extend(local_targets(self.gate.wires_evaluation_point()));
for i in 0..self.gate.num_points { for i in 0..num_points {
deps.push(local_target(self.gate.wire_point(i)));
deps.extend(local_targets(self.gate.wires_value(i))); deps.extend(local_targets(self.gate.wires_value(i)));
} }
deps deps
} }
fn run_once(&self, witness: &PartitionWitness<F>, out_buffer: &mut GeneratedValues<F>) { fn run_once(&self, witness: &PartitionWitness<F>, out_buffer: &mut GeneratedValues<F>) {
let n = self.gate.num_points;
let local_wire = |input| Wire { let local_wire = |input| Wire {
gate: self.gate_index, gate: self.gate_index,
input, input,
@ -267,13 +296,11 @@ impl<F: RichField + Extendable<D>, const D: usize> SimpleGenerator<F>
}; };
// Compute the interpolant. // Compute the interpolant.
let points = (0..n) let points = self.gate.coset(get_local_wire(self.gate.wire_shift()));
.map(|i| { let points = points
( .into_iter()
F::Extension::from_basefield(get_local_wire(self.gate.wire_point(i))), .enumerate()
get_local_ext(self.gate.wires_value(i)), .map(|(i, point)| (point.into(), get_local_ext(self.gate.wires_value(i))))
)
})
.collect::<Vec<_>>(); .collect::<Vec<_>>();
let interpolant = interpolant(&points); let interpolant = interpolant(&points);
@ -308,31 +335,30 @@ mod tests {
#[test] #[test]
fn wire_indices() { fn wire_indices() {
let gate = InterpolationGate::<GoldilocksField, 4> { let gate = InterpolationGate::<GoldilocksField, 4> {
num_points: 2, subgroup_bits: 1,
_phantom: PhantomData, _phantom: PhantomData,
}; };
// The exact indices aren't really important, but we want to make sure we don't have any // The exact indices aren't really important, but we want to make sure we don't have any
// overlaps or gaps. // overlaps or gaps.
assert_eq!(gate.wire_point(0), 0); assert_eq!(gate.wire_shift(), 0);
assert_eq!(gate.wire_point(1), 1); assert_eq!(gate.wires_value(0), 1..5);
assert_eq!(gate.wires_value(0), 2..6); assert_eq!(gate.wires_value(1), 5..9);
assert_eq!(gate.wires_value(1), 6..10); assert_eq!(gate.wires_evaluation_point(), 9..13);
assert_eq!(gate.wires_evaluation_point(), 10..14); assert_eq!(gate.wires_evaluation_value(), 13..17);
assert_eq!(gate.wires_evaluation_value(), 14..18); assert_eq!(gate.wires_coeff(0), 17..21);
assert_eq!(gate.wires_coeff(0), 18..22); assert_eq!(gate.wires_coeff(1), 21..25);
assert_eq!(gate.wires_coeff(1), 22..26); assert_eq!(gate.num_wires(), 25);
assert_eq!(gate.num_wires(), 26);
} }
#[test] #[test]
fn low_degree() { fn low_degree() {
test_low_degree::<GoldilocksField, _, 4>(InterpolationGate::new(4)); test_low_degree::<GoldilocksField, _, 4>(InterpolationGate::new(2));
} }
#[test] #[test]
fn eval_fns() -> Result<()> { fn eval_fns() -> Result<()> {
test_eval_fns::<GoldilocksField, _, 4>(InterpolationGate::new(4)) test_eval_fns::<GoldilocksField, _, 4>(InterpolationGate::new(2))
} }
#[test] #[test]
@ -343,15 +369,15 @@ mod tests {
/// Returns the local wires for an interpolation gate for given coeffs, points and eval point. /// Returns the local wires for an interpolation gate for given coeffs, points and eval point.
fn get_wires( fn get_wires(
num_points: usize, gate: &InterpolationGate<F, D>,
shift: F,
coeffs: PolynomialCoeffs<FF>, coeffs: PolynomialCoeffs<FF>,
points: Vec<F>,
eval_point: FF, eval_point: FF,
) -> Vec<FF> { ) -> Vec<FF> {
let mut v = Vec::new(); let points = gate.coset(shift);
v.extend_from_slice(&points); let mut v = vec![shift];
for j in 0..num_points { for x in points {
v.extend(coeffs.eval(points[j].into()).0); v.extend(coeffs.eval(x.into()).0);
} }
v.extend(eval_point.0); v.extend(eval_point.0);
v.extend(coeffs.eval(eval_point).0); v.extend(coeffs.eval(eval_point).0);
@ -362,16 +388,13 @@ mod tests {
} }
// Get a working row for InterpolationGate. // Get a working row for InterpolationGate.
let shift = F::rand();
let coeffs = PolynomialCoeffs::new(vec![FF::rand(), FF::rand()]); let coeffs = PolynomialCoeffs::new(vec![FF::rand(), FF::rand()]);
let points = vec![F::rand(), F::rand()];
let eval_point = FF::rand(); let eval_point = FF::rand();
let gate = InterpolationGate::<F, D> { let gate = InterpolationGate::<F, D>::new(1);
num_points: 2,
_phantom: PhantomData,
};
let vars = EvaluationVars { let vars = EvaluationVars {
local_constants: &[], local_constants: &[],
local_wires: &get_wires(2, coeffs, points, eval_point), local_wires: &get_wires(&gate, shift, coeffs, eval_point),
public_inputs_hash: &HashOut::rand(), public_inputs_hash: &HashOut::rand(),
}; };

View File

@ -53,7 +53,7 @@ impl CircuitConfig {
pub(crate) fn standard_recursion_config() -> Self { pub(crate) fn standard_recursion_config() -> Self {
Self { Self {
num_wires: 143, num_wires: 143,
num_routed_wires: 28, num_routed_wires: 25,
constant_gate_size: 6, constant_gate_size: 6,
security_bits: 100, security_bits: 100,
rate_bits: 3, rate_bits: 3,