From 5200d70cf0832d251962e0e0c967a04183052d63 Mon Sep 17 00:00:00 2001 From: wborgeaud Date: Fri, 11 Jun 2021 16:22:29 +0200 Subject: [PATCH] Add interpolation gadgets --- src/field/extension_field/target.rs | 12 ++++ src/fri/recursive_verifier.rs | 48 ++++++++------- src/gadgets/arithmetic.rs | 2 +- src/gadgets/interpolation.rs | 96 ++++++++++++++++++++++++++++- src/gadgets/rotate.rs | 76 ++++++++++++++++++----- src/gadgets/split_join.rs | 9 +-- src/gates/interpolation.rs | 4 +- src/gates/mod.rs | 2 +- 8 files changed, 199 insertions(+), 50 deletions(-) diff --git a/src/field/extension_field/target.rs b/src/field/extension_field/target.rs index 2e743098..5a3cd403 100644 --- a/src/field/extension_field/target.rs +++ b/src/field/extension_field/target.rs @@ -240,6 +240,18 @@ impl, const D: usize> CircuitBuilder { self.add_extension(product, c) } + /// Like `mul_sub`, but for `ExtensionTarget`s. Note that, unlike `mul_sub`, this has no + /// performance benefit over separate muls and subs. + pub fn scalar_mul_sub_extension( + &mut self, + a: Target, + b: ExtensionTarget, + c: ExtensionTarget, + ) -> ExtensionTarget { + let product = self.scalar_mul_ext(a, b); + self.sub_extension(product, c) + } + /// Returns `a * b`, where `b` is in the extension field and `a` is in the base field. pub fn scalar_mul_ext(&mut self, a: Target, mut b: ExtensionTarget) -> ExtensionTarget { for i in 0..D { diff --git a/src/fri/recursive_verifier.rs b/src/fri/recursive_verifier.rs index 3a4332cf..9eb1a1c0 100644 --- a/src/fri/recursive_verifier.rs +++ b/src/fri/recursive_verifier.rs @@ -29,24 +29,28 @@ impl, const D: usize> CircuitBuilder { last_evals: &[ExtensionTarget], beta: ExtensionTarget, ) -> ExtensionTarget { - todo!() - // debug_assert_eq!(last_evals.len(), 1 << arity_bits); - // - // let g = F::primitive_root_of_unity(arity_bits); - // - // // The evaluation vector needs to be reordered first. - // let mut evals = last_evals.to_vec(); - // reverse_index_bits_in_place(&mut evals); - // evals.rotate_left(reverse_bits(old_x_index, arity_bits)); - // - // // The answer is gotten by interpolating {(x*g^i, P(x*g^i))} and evaluating at beta. - // let points = g - // .powers() - // .zip(evals) - // .map(|(y, e)| ((x * y).into(), e)) - // .collect::>(); - // let barycentric_weights = barycentric_weights(&points); - // interpolate(&points, beta, &barycentric_weights) + debug_assert_eq!(last_evals.len(), 1 << arity_bits); + + let g = F::primitive_root_of_unity(arity_bits); + + // The evaluation vector needs to be reordered first. + let mut evals = last_evals.to_vec(); + reverse_index_bits_in_place(&mut evals); + let mut old_x_index_bits = self.split_le(old_x_index, arity_bits); + old_x_index_bits.reverse(); + self.rotate_left_from_bits(&old_x_index_bits, &evals, arity_bits); + + // The answer is gotten by interpolating {(x*g^i, P(x*g^i))} and evaluating at beta. + let points = g + .powers() + .zip(evals) + .map(|(y, e)| { + let yt = self.constant(y); + (self.mul(x, yt), e) + }) + .collect::>(); + + self.interpolate(&points, beta) } fn fri_verify_proof_of_work( @@ -205,8 +209,8 @@ impl, const D: usize> CircuitBuilder { let a = alpha_powers.next(self); ev_zeta_right = self.mul_add_extension(a, t, ev_zeta); } - let zs_interpol = self.interpolate2([(zeta, ev_zeta), (zeta_right, ev_zeta_right)]); - let interpol_val = zs_interpol.eval(self, subgroup_x); + let interpol_val = + self.interpolate2([(zeta, ev_zeta), (zeta_right, ev_zeta_right)], subgroup_x); let numerator = self.sub_extension(ev, interpol_val); let vanish = self.sub_extension(subgroup_x, zeta); let vanish_right = self.sub_extension(subgroup_x, zeta_right); @@ -238,8 +242,8 @@ impl, const D: usize> CircuitBuilder { self.mul_add_extension(a, w, acc) }) .frobenius(self); - let wires_interpol = self.interpolate2([(zeta, wire_eval), (zeta_frob, wire_eval_frob)]); - let interpol_val = wires_interpol.eval(self, subgroup_x); + let interpol_val = + self.interpolate2([(zeta, wire_eval), (zeta_frob, wire_eval_frob)], subgroup_x); let numerator = self.sub_extension(ev, interpol_val); let vanish_frob = self.sub_extension(subgroup_x, zeta_frob); let denominator = self.mul_extension(vanish, vanish_frob); diff --git a/src/gadgets/arithmetic.rs b/src/gadgets/arithmetic.rs index c6c3b2b7..8bf3a797 100644 --- a/src/gadgets/arithmetic.rs +++ b/src/gadgets/arithmetic.rs @@ -177,7 +177,7 @@ impl, const D: usize> CircuitBuilder { let mut current = base; let one = self.one(); let mut product = one; - let exponent_bits = self.split_le(exponent); + let exponent_bits = self.split_le(exponent, 64); for bit in exponent_bits.into_iter() { product = self.mul_many(&[bit, current, product]); diff --git a/src/gadgets/interpolation.rs b/src/gadgets/interpolation.rs index 6d44cd76..e867620c 100644 --- a/src/gadgets/interpolation.rs +++ b/src/gadgets/interpolation.rs @@ -2,14 +2,104 @@ use crate::circuit_builder::CircuitBuilder; use crate::field::extension_field::target::ExtensionTarget; use crate::field::extension_field::Extendable; use crate::gadgets::polynomial::PolynomialCoeffsExtTarget; +use crate::gates::interpolation::InterpolationGate; +use crate::target::Target; +use std::marker::PhantomData; impl, const D: usize> CircuitBuilder { /// Interpolate two points. No need for an `InterpolationGate` since the coefficients /// of the linear interpolation polynomial can be easily computed with arithmetic operations. pub fn interpolate2( &mut self, - points: [(ExtensionTarget, ExtensionTarget); 2], - ) -> PolynomialCoeffsExtTarget { - todo!() + interpolation_points: [(ExtensionTarget, ExtensionTarget); 2], + evaluation_point: ExtensionTarget, + ) -> ExtensionTarget { + // a0 -> a1 + // b0 -> b1 + // x -> a1 + (x-a0)*(b1-a1)/(b0-a0) + + let x_m_a0 = self.sub_extension(evaluation_point, interpolation_points[0].0); + let b1_m_a1 = self.sub_extension(interpolation_points[1].1, interpolation_points[0].1); + let b0_m_a0 = self.sub_extension(interpolation_points[1].0, interpolation_points[0].0); + let quotient = self.div_unsafe_extension(b1_m_a1, b0_m_a0); + + self.mul_add_extension(x_m_a0, quotient, interpolation_points[0].1) + } + + /// Interpolate a list of point/evaluation pairs at a given point. + /// Returns the evaluation of the interpolated polynomial at `evaluation_point`. + pub fn interpolate( + &mut self, + interpolation_points: &[(Target, ExtensionTarget)], + evaluation_point: ExtensionTarget, + ) -> ExtensionTarget { + let gate = InterpolationGate:: { + num_points: interpolation_points.len(), + _phantom: PhantomData, + }; + let gate_index = + self.add_gate_no_constants(InterpolationGate::new(interpolation_points.len())); + for (i, &(p, v)) in interpolation_points.iter().enumerate() { + self.route(p, Target::wire(gate_index, gate.wire_point(i))); + self.route_extension( + v, + ExtensionTarget::from_range(gate_index, gate.wires_value(i)), + ); + } + self.route_extension( + evaluation_point, + ExtensionTarget::from_range(gate_index, gate.wires_evaluation_point()), + ); + + ExtensionTarget::from_range(gate_index, gate.wires_evaluation_value()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::circuit_data::CircuitConfig; + use crate::field::crandall_field::CrandallField; + use crate::field::extension_field::quartic::QuarticCrandallField; + use crate::field::extension_field::FieldExtension; + use crate::field::field::Field; + use crate::field::lagrange::{interpolant, interpolate}; + use crate::witness::PartialWitness; + + #[test] + fn test_interpolate() { + type F = CrandallField; + type FF = QuarticCrandallField; + let config = CircuitConfig::large_config(); + let mut builder = CircuitBuilder::::new(config); + + let len = 2; + let points = (0..len) + .map(|_| (F::rand(), FF::rand())) + .collect::>(); + + let homogeneous_points = points + .iter() + .map(|&(a, b)| (>::from_basefield(a), b)) + .collect::>(); + + let true_interpolant = interpolant(&homogeneous_points); + + let z = FF::rand(); + let true_eval = true_interpolant.eval(z); + + let points_target = points + .iter() + .map(|&(p, v)| (builder.constant(p), builder.constant_extension(v))) + .collect::>(); + + let zt = builder.constant_extension(z); + + let eval = builder.interpolate(&points_target, zt); + let true_eval_target = builder.constant_extension(true_eval); + builder.assert_equal_extension(eval, true_eval_target); + + let data = builder.build(); + let proof = data.prove(PartialWitness::new()); } } diff --git a/src/gadgets/rotate.rs b/src/gadgets/rotate.rs index 19230dc0..7afbdfa3 100644 --- a/src/gadgets/rotate.rs +++ b/src/gadgets/rotate.rs @@ -1,4 +1,5 @@ use crate::circuit_builder::CircuitBuilder; +use crate::field::extension_field::target::ExtensionTarget; use crate::field::extension_field::Extendable; use crate::field::field::Field; use crate::gates::base_sum::BaseSumGate; @@ -10,13 +11,24 @@ impl, const D: usize> CircuitBuilder { /// Selects `x` or `y` based on `b`, which is assumed to be binary. /// In particular, this returns `if b { x } else { y }`. /// Note: This does not range-check `b`. - pub fn select(&mut self, b: Target, x: Target, y: Target) -> Target { - let b_y_minus_y = self.mul_sub(b, y, y); - self.mul_sub(b, x, b_y_minus_y) + pub fn select( + &mut self, + b: Target, + x: ExtensionTarget, + y: ExtensionTarget, + ) -> ExtensionTarget { + let b_y_minus_y = self.scalar_mul_sub_extension(b, y, y); + self.scalar_mul_sub_extension(b, x, b_y_minus_y) } /// Left-rotates an array `k` times if `b=1` else return the same array. - pub fn rotate_fixed(&mut self, b: Target, k: usize, v: &[Target], len: usize) -> Vec { + pub fn rotate_left_fixed( + &mut self, + b: Target, + k: usize, + v: &[ExtensionTarget], + len: usize, + ) -> Vec> { let mut res = Vec::new(); for i in 0..len { @@ -29,16 +41,40 @@ impl, const D: usize> CircuitBuilder { /// Left-rotates an array by `num_rotation`. Assumes that `num_rotation` is range-checked to be /// less than `len`. /// Note: We assume `len` is less than 8 since we won't use any arity greater than 8 in FRI (maybe?). - pub fn rotate(&mut self, num_rotation: Target, v: &[Target], len: usize) -> Vec { + pub fn rotate_left_from_bits( + &mut self, + num_rotation_bits: &[Target], + v: &[ExtensionTarget], + len_log: usize, + ) -> Vec> { + debug_assert_eq!(num_rotation_bits.len(), len_log); + let len = 1 << len_log; debug_assert_eq!(v.len(), len); - let bits = self.split_le_base::<2>(num_rotation, 3); + let mut v = v.to_vec(); - let v = self.rotate_fixed(bits[0], 1, v, len); - let v = self.rotate_fixed(bits[1], 2, &v, len); - let v = self.rotate_fixed(bits[2], 4, &v, len); + for i in 0..len_log { + v = self.rotate_left_fixed(num_rotation_bits[i], 1 << i, &v, len); + } v } + + /// Left-rotates an array by `num_rotation`. Assumes that `num_rotation` is range-checked to be + /// less than `len`. + /// Note: We assume `len` is a power of two less than or equal to 8, since we won't use any + /// arity greater than 8 in FRI (maybe?). + pub fn rotate_left( + &mut self, + num_rotation: Target, + v: &[ExtensionTarget], + len_log: usize, + ) -> Vec> { + let len = 1 << len_log; + debug_assert_eq!(v.len(), len); + let bits = self.split_le(num_rotation, len_log); + + self.rotate_left_from_bits(&bits, v, len_log) + } } #[cfg(test)] @@ -46,28 +82,34 @@ mod tests { use super::*; use crate::circuit_data::CircuitConfig; use crate::field::crandall_field::CrandallField; + use crate::field::extension_field::quartic::QuarticCrandallField; - fn real_rotate(num_rotation: usize, v: &[Target]) -> Vec { + fn real_rotate( + num_rotation: usize, + v: &[ExtensionTarget], + ) -> Vec> { let mut res = v.to_vec(); res.rotate_left(num_rotation); res } - fn test_rotate_given_len(len: usize) { + fn test_rotate_given_len(len_log: usize) { type F = CrandallField; + type FF = QuarticCrandallField; + let len = 1 << len_log; let config = CircuitConfig::large_config(); let mut builder = CircuitBuilder::::new(config); let v = (0..len) - .map(|_| builder.constant(F::rand())) - .collect::>(); // 416 = 1532 in base 6. + .map(|_| builder.constant_extension(FF::rand())) + .collect::>(); for i in 0..len { let it = builder.constant(F::from_canonical_usize(i)); let rotated = real_rotate(i, &v); - let purported_rotated = builder.rotate(it, &v, len); + let purported_rotated = builder.rotate_left(it, &v, len_log); for (x, y) in rotated.into_iter().zip(purported_rotated) { - builder.assert_equal(x, y); + builder.assert_equal_extension(x, y); } } @@ -77,8 +119,8 @@ mod tests { #[test] fn test_rotate() { - for i_log in 1..4 { - test_rotate_given_len(1 << i_log); + for len_log in 1..4 { + test_rotate_given_len(len_log); } } } diff --git a/src/gadgets/split_join.rs b/src/gadgets/split_join.rs index f4118d4e..647b0ef5 100644 --- a/src/gadgets/split_join.rs +++ b/src/gadgets/split_join.rs @@ -27,21 +27,22 @@ impl, const D: usize> CircuitBuilder { /// Split the given integer into a list of wires, where each one represents a /// bit of the integer, with little-endian ordering. /// Verifies that the decomposition is correct by using `k` `BaseSum<2>` gates - /// with `k` such that `k*num_routed_bits>=64`. - pub(crate) fn split_le(&mut self, integer: Target) -> Vec { + /// with `k` such that `k*num_routed_wires>=num_bits`. + pub(crate) fn split_le(&mut self, integer: Target, num_bits: usize) -> Vec { let num_limbs = self.config.num_routed_wires - BaseSumGate::<2>::WIRE_LIMBS_START; - let k = ceil_div_usize(64, num_limbs); + let k = ceil_div_usize(num_bits, num_limbs); let gates = (0..k) .map(|_| self.add_gate_no_constants(BaseSumGate::<2>::new(num_limbs))) .collect::>(); - let mut bits = Vec::with_capacity(64); + let mut bits = Vec::with_capacity(num_bits); for &gate in &gates { bits.extend(Target::wires_from_range( gate, BaseSumGate::<2>::WIRE_LIMBS_START..BaseSumGate::<2>::WIRE_LIMBS_START + num_limbs, )); } + bits.drain(num_bits..); let zero = self.zero(); let mut acc = zero; diff --git a/src/gates/interpolation.rs b/src/gates/interpolation.rs index b3fd35d0..8cea3328 100644 --- a/src/gates/interpolation.rs +++ b/src/gates/interpolation.rs @@ -22,8 +22,8 @@ use crate::witness::PartialWitness; /// given point. #[derive(Clone, Debug)] pub(crate) struct InterpolationGate, const D: usize> { - num_points: usize, - _phantom: PhantomData, + pub num_points: usize, + pub _phantom: PhantomData, } impl, const D: usize> InterpolationGate { diff --git a/src/gates/mod.rs b/src/gates/mod.rs index 3ac7bd74..d013c7cd 100644 --- a/src/gates/mod.rs +++ b/src/gates/mod.rs @@ -3,7 +3,7 @@ pub mod base_sum; pub mod constant; pub(crate) mod gate; pub mod gmimc; -mod interpolation; +pub mod interpolation; pub mod mul_extension; pub(crate) mod noop;