Add interpolation gadgets

This commit is contained in:
wborgeaud 2021-06-11 16:22:29 +02:00
parent 4b1f368e89
commit 5200d70cf0
8 changed files with 199 additions and 50 deletions

View File

@ -240,6 +240,18 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
self.add_extension(product, c)
}
/// Like `mul_sub`, but for `ExtensionTarget`s. Note that, unlike `mul_sub`, this has no
/// performance benefit over separate muls and subs.
pub fn scalar_mul_sub_extension(
&mut self,
a: Target,
b: ExtensionTarget<D>,
c: ExtensionTarget<D>,
) -> ExtensionTarget<D> {
let product = self.scalar_mul_ext(a, b);
self.sub_extension(product, c)
}
/// Returns `a * b`, where `b` is in the extension field and `a` is in the base field.
pub fn scalar_mul_ext(&mut self, a: Target, mut b: ExtensionTarget<D>) -> ExtensionTarget<D> {
for i in 0..D {

View File

@ -29,24 +29,28 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
last_evals: &[ExtensionTarget<D>],
beta: ExtensionTarget<D>,
) -> ExtensionTarget<D> {
todo!()
// debug_assert_eq!(last_evals.len(), 1 << arity_bits);
//
// let g = F::primitive_root_of_unity(arity_bits);
//
// // The evaluation vector needs to be reordered first.
// let mut evals = last_evals.to_vec();
// reverse_index_bits_in_place(&mut evals);
// evals.rotate_left(reverse_bits(old_x_index, arity_bits));
//
// // The answer is gotten by interpolating {(x*g^i, P(x*g^i))} and evaluating at beta.
// let points = g
// .powers()
// .zip(evals)
// .map(|(y, e)| ((x * y).into(), e))
// .collect::<Vec<_>>();
// let barycentric_weights = barycentric_weights(&points);
// interpolate(&points, beta, &barycentric_weights)
debug_assert_eq!(last_evals.len(), 1 << arity_bits);
let g = F::primitive_root_of_unity(arity_bits);
// The evaluation vector needs to be reordered first.
let mut evals = last_evals.to_vec();
reverse_index_bits_in_place(&mut evals);
let mut old_x_index_bits = self.split_le(old_x_index, arity_bits);
old_x_index_bits.reverse();
self.rotate_left_from_bits(&old_x_index_bits, &evals, arity_bits);
// The answer is gotten by interpolating {(x*g^i, P(x*g^i))} and evaluating at beta.
let points = g
.powers()
.zip(evals)
.map(|(y, e)| {
let yt = self.constant(y);
(self.mul(x, yt), e)
})
.collect::<Vec<_>>();
self.interpolate(&points, beta)
}
fn fri_verify_proof_of_work(
@ -205,8 +209,8 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
let a = alpha_powers.next(self);
ev_zeta_right = self.mul_add_extension(a, t, ev_zeta);
}
let zs_interpol = self.interpolate2([(zeta, ev_zeta), (zeta_right, ev_zeta_right)]);
let interpol_val = zs_interpol.eval(self, subgroup_x);
let interpol_val =
self.interpolate2([(zeta, ev_zeta), (zeta_right, ev_zeta_right)], subgroup_x);
let numerator = self.sub_extension(ev, interpol_val);
let vanish = self.sub_extension(subgroup_x, zeta);
let vanish_right = self.sub_extension(subgroup_x, zeta_right);
@ -238,8 +242,8 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
self.mul_add_extension(a, w, acc)
})
.frobenius(self);
let wires_interpol = self.interpolate2([(zeta, wire_eval), (zeta_frob, wire_eval_frob)]);
let interpol_val = wires_interpol.eval(self, subgroup_x);
let interpol_val =
self.interpolate2([(zeta, wire_eval), (zeta_frob, wire_eval_frob)], subgroup_x);
let numerator = self.sub_extension(ev, interpol_val);
let vanish_frob = self.sub_extension(subgroup_x, zeta_frob);
let denominator = self.mul_extension(vanish, vanish_frob);

View File

@ -177,7 +177,7 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
let mut current = base;
let one = self.one();
let mut product = one;
let exponent_bits = self.split_le(exponent);
let exponent_bits = self.split_le(exponent, 64);
for bit in exponent_bits.into_iter() {
product = self.mul_many(&[bit, current, product]);

View File

@ -2,14 +2,104 @@ use crate::circuit_builder::CircuitBuilder;
use crate::field::extension_field::target::ExtensionTarget;
use crate::field::extension_field::Extendable;
use crate::gadgets::polynomial::PolynomialCoeffsExtTarget;
use crate::gates::interpolation::InterpolationGate;
use crate::target::Target;
use std::marker::PhantomData;
impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
/// Interpolate two points. No need for an `InterpolationGate` since the coefficients
/// of the linear interpolation polynomial can be easily computed with arithmetic operations.
pub fn interpolate2(
&mut self,
points: [(ExtensionTarget<D>, ExtensionTarget<D>); 2],
) -> PolynomialCoeffsExtTarget<D> {
todo!()
interpolation_points: [(ExtensionTarget<D>, ExtensionTarget<D>); 2],
evaluation_point: ExtensionTarget<D>,
) -> ExtensionTarget<D> {
// a0 -> a1
// b0 -> b1
// x -> a1 + (x-a0)*(b1-a1)/(b0-a0)
let x_m_a0 = self.sub_extension(evaluation_point, interpolation_points[0].0);
let b1_m_a1 = self.sub_extension(interpolation_points[1].1, interpolation_points[0].1);
let b0_m_a0 = self.sub_extension(interpolation_points[1].0, interpolation_points[0].0);
let quotient = self.div_unsafe_extension(b1_m_a1, b0_m_a0);
self.mul_add_extension(x_m_a0, quotient, interpolation_points[0].1)
}
/// Interpolate a list of point/evaluation pairs at a given point.
/// Returns the evaluation of the interpolated polynomial at `evaluation_point`.
pub fn interpolate(
&mut self,
interpolation_points: &[(Target, ExtensionTarget<D>)],
evaluation_point: ExtensionTarget<D>,
) -> ExtensionTarget<D> {
let gate = InterpolationGate::<F, D> {
num_points: interpolation_points.len(),
_phantom: PhantomData,
};
let gate_index =
self.add_gate_no_constants(InterpolationGate::new(interpolation_points.len()));
for (i, &(p, v)) in interpolation_points.iter().enumerate() {
self.route(p, Target::wire(gate_index, gate.wire_point(i)));
self.route_extension(
v,
ExtensionTarget::from_range(gate_index, gate.wires_value(i)),
);
}
self.route_extension(
evaluation_point,
ExtensionTarget::from_range(gate_index, gate.wires_evaluation_point()),
);
ExtensionTarget::from_range(gate_index, gate.wires_evaluation_value())
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::circuit_data::CircuitConfig;
use crate::field::crandall_field::CrandallField;
use crate::field::extension_field::quartic::QuarticCrandallField;
use crate::field::extension_field::FieldExtension;
use crate::field::field::Field;
use crate::field::lagrange::{interpolant, interpolate};
use crate::witness::PartialWitness;
#[test]
fn test_interpolate() {
type F = CrandallField;
type FF = QuarticCrandallField;
let config = CircuitConfig::large_config();
let mut builder = CircuitBuilder::<F, 4>::new(config);
let len = 2;
let points = (0..len)
.map(|_| (F::rand(), FF::rand()))
.collect::<Vec<_>>();
let homogeneous_points = points
.iter()
.map(|&(a, b)| (<FF as FieldExtension<4>>::from_basefield(a), b))
.collect::<Vec<_>>();
let true_interpolant = interpolant(&homogeneous_points);
let z = FF::rand();
let true_eval = true_interpolant.eval(z);
let points_target = points
.iter()
.map(|&(p, v)| (builder.constant(p), builder.constant_extension(v)))
.collect::<Vec<_>>();
let zt = builder.constant_extension(z);
let eval = builder.interpolate(&points_target, zt);
let true_eval_target = builder.constant_extension(true_eval);
builder.assert_equal_extension(eval, true_eval_target);
let data = builder.build();
let proof = data.prove(PartialWitness::new());
}
}

View File

@ -1,4 +1,5 @@
use crate::circuit_builder::CircuitBuilder;
use crate::field::extension_field::target::ExtensionTarget;
use crate::field::extension_field::Extendable;
use crate::field::field::Field;
use crate::gates::base_sum::BaseSumGate;
@ -10,13 +11,24 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
/// Selects `x` or `y` based on `b`, which is assumed to be binary.
/// In particular, this returns `if b { x } else { y }`.
/// Note: This does not range-check `b`.
pub fn select(&mut self, b: Target, x: Target, y: Target) -> Target {
let b_y_minus_y = self.mul_sub(b, y, y);
self.mul_sub(b, x, b_y_minus_y)
pub fn select(
&mut self,
b: Target,
x: ExtensionTarget<D>,
y: ExtensionTarget<D>,
) -> ExtensionTarget<D> {
let b_y_minus_y = self.scalar_mul_sub_extension(b, y, y);
self.scalar_mul_sub_extension(b, x, b_y_minus_y)
}
/// Left-rotates an array `k` times if `b=1` else return the same array.
pub fn rotate_fixed(&mut self, b: Target, k: usize, v: &[Target], len: usize) -> Vec<Target> {
pub fn rotate_left_fixed(
&mut self,
b: Target,
k: usize,
v: &[ExtensionTarget<D>],
len: usize,
) -> Vec<ExtensionTarget<D>> {
let mut res = Vec::new();
for i in 0..len {
@ -29,16 +41,40 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
/// Left-rotates an array by `num_rotation`. Assumes that `num_rotation` is range-checked to be
/// less than `len`.
/// Note: We assume `len` is less than 8 since we won't use any arity greater than 8 in FRI (maybe?).
pub fn rotate(&mut self, num_rotation: Target, v: &[Target], len: usize) -> Vec<Target> {
pub fn rotate_left_from_bits(
&mut self,
num_rotation_bits: &[Target],
v: &[ExtensionTarget<D>],
len_log: usize,
) -> Vec<ExtensionTarget<D>> {
debug_assert_eq!(num_rotation_bits.len(), len_log);
let len = 1 << len_log;
debug_assert_eq!(v.len(), len);
let bits = self.split_le_base::<2>(num_rotation, 3);
let mut v = v.to_vec();
let v = self.rotate_fixed(bits[0], 1, v, len);
let v = self.rotate_fixed(bits[1], 2, &v, len);
let v = self.rotate_fixed(bits[2], 4, &v, len);
for i in 0..len_log {
v = self.rotate_left_fixed(num_rotation_bits[i], 1 << i, &v, len);
}
v
}
/// Left-rotates an array by `num_rotation`. Assumes that `num_rotation` is range-checked to be
/// less than `len`.
/// Note: We assume `len` is a power of two less than or equal to 8, since we won't use any
/// arity greater than 8 in FRI (maybe?).
pub fn rotate_left(
&mut self,
num_rotation: Target,
v: &[ExtensionTarget<D>],
len_log: usize,
) -> Vec<ExtensionTarget<D>> {
let len = 1 << len_log;
debug_assert_eq!(v.len(), len);
let bits = self.split_le(num_rotation, len_log);
self.rotate_left_from_bits(&bits, v, len_log)
}
}
#[cfg(test)]
@ -46,28 +82,34 @@ mod tests {
use super::*;
use crate::circuit_data::CircuitConfig;
use crate::field::crandall_field::CrandallField;
use crate::field::extension_field::quartic::QuarticCrandallField;
fn real_rotate(num_rotation: usize, v: &[Target]) -> Vec<Target> {
fn real_rotate<const D: usize>(
num_rotation: usize,
v: &[ExtensionTarget<D>],
) -> Vec<ExtensionTarget<D>> {
let mut res = v.to_vec();
res.rotate_left(num_rotation);
res
}
fn test_rotate_given_len(len: usize) {
fn test_rotate_given_len(len_log: usize) {
type F = CrandallField;
type FF = QuarticCrandallField;
let len = 1 << len_log;
let config = CircuitConfig::large_config();
let mut builder = CircuitBuilder::<F, 4>::new(config);
let v = (0..len)
.map(|_| builder.constant(F::rand()))
.collect::<Vec<_>>(); // 416 = 1532 in base 6.
.map(|_| builder.constant_extension(FF::rand()))
.collect::<Vec<_>>();
for i in 0..len {
let it = builder.constant(F::from_canonical_usize(i));
let rotated = real_rotate(i, &v);
let purported_rotated = builder.rotate(it, &v, len);
let purported_rotated = builder.rotate_left(it, &v, len_log);
for (x, y) in rotated.into_iter().zip(purported_rotated) {
builder.assert_equal(x, y);
builder.assert_equal_extension(x, y);
}
}
@ -77,8 +119,8 @@ mod tests {
#[test]
fn test_rotate() {
for i_log in 1..4 {
test_rotate_given_len(1 << i_log);
for len_log in 1..4 {
test_rotate_given_len(len_log);
}
}
}

View File

@ -27,21 +27,22 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
/// Split the given integer into a list of wires, where each one represents a
/// bit of the integer, with little-endian ordering.
/// Verifies that the decomposition is correct by using `k` `BaseSum<2>` gates
/// with `k` such that `k*num_routed_bits>=64`.
pub(crate) fn split_le(&mut self, integer: Target) -> Vec<Target> {
/// with `k` such that `k*num_routed_wires>=num_bits`.
pub(crate) fn split_le(&mut self, integer: Target, num_bits: usize) -> Vec<Target> {
let num_limbs = self.config.num_routed_wires - BaseSumGate::<2>::WIRE_LIMBS_START;
let k = ceil_div_usize(64, num_limbs);
let k = ceil_div_usize(num_bits, num_limbs);
let gates = (0..k)
.map(|_| self.add_gate_no_constants(BaseSumGate::<2>::new(num_limbs)))
.collect::<Vec<_>>();
let mut bits = Vec::with_capacity(64);
let mut bits = Vec::with_capacity(num_bits);
for &gate in &gates {
bits.extend(Target::wires_from_range(
gate,
BaseSumGate::<2>::WIRE_LIMBS_START..BaseSumGate::<2>::WIRE_LIMBS_START + num_limbs,
));
}
bits.drain(num_bits..);
let zero = self.zero();
let mut acc = zero;

View File

@ -22,8 +22,8 @@ use crate::witness::PartialWitness;
/// given point.
#[derive(Clone, Debug)]
pub(crate) struct InterpolationGate<F: Extendable<D>, const D: usize> {
num_points: usize,
_phantom: PhantomData<F>,
pub num_points: usize,
pub _phantom: PhantomData<F>,
}
impl<F: Extendable<D>, const D: usize> InterpolationGate<F, D> {

View File

@ -3,7 +3,7 @@ pub mod base_sum;
pub mod constant;
pub(crate) mod gate;
pub mod gmimc;
mod interpolation;
pub mod interpolation;
pub mod mul_extension;
pub(crate) mod noop;