diff --git a/src/fri/recursive_verifier.rs b/src/fri/recursive_verifier.rs index c9f15aec..2784ceef 100644 --- a/src/fri/recursive_verifier.rs +++ b/src/fri/recursive_verifier.rs @@ -41,23 +41,14 @@ impl, const D: usize> CircuitBuilder { let coset_start = self.mul(start, x); // The answer is gotten by interpolating {(x*g^i, P(x*g^i))} and evaluating at beta. - let g_powers = g + let points = g .powers() - .take(arity) - .map(|y| self.constant(y)) + .zip(evals) + .map(|(y, v)| { + let yc = self.constant(y); + (self.mul(coset_start, yc), v) + }) .collect::>(); - let mut coset = Vec::new(); - for i in 0..arity / 2 { - let res = self.mul_two( - coset_start, - g_powers[2 * i], - coset_start, - g_powers[2 * i + 1], - ); - coset.push(res.0); - coset.push(res.1); - } - let points = coset.into_iter().zip(evals).collect::>(); self.interpolate(&points, beta) } @@ -265,14 +256,11 @@ impl, const D: usize> CircuitBuilder { precomputed_reduced_evals.slope, precomputed_reduced_evals.zs, ); - let (zs_numerator, vanish_zeta_right) = self.sub_two_extension( - zs_composition_eval, - interpol_val, - subgroup_x, - precomputed_reduced_evals.zeta_right, - ); - let (mut sum, zs_denominator) = - alpha.shift_and_mul(sum, vanish_zeta, vanish_zeta_right, self); + let zs_numerator = self.sub_extension(zs_composition_eval, interpol_val); + let vanish_zeta_right = + self.sub_extension(subgroup_x, precomputed_reduced_evals.zeta_right); + sum = alpha.shift(sum, self); + let zs_denominator = self.mul_extension(vanish_zeta, vanish_zeta_right); sum = self.div_add_extension(zs_numerator, zs_denominator, sum); sum @@ -319,17 +307,9 @@ impl, const D: usize> CircuitBuilder { let phi_ext = self.convert_to_ext(phi); let zero = self.zero_extension(); // `subgroup_x = g*phi, vanish_zeta = g*phi - zeta` - let tmp = self.double_arithmetic_extension( - F::ONE, - F::NEG_ONE, - g_ext, - phi_ext, - zero, - g_ext, - phi_ext, - zeta, - ); - (tmp.0 .0[0], tmp.1) + let subgroup_x = self.mul(g, phi); + let vanish_zeta = self.mul_sub_extension(g_ext, phi_ext, zeta); + (subgroup_x, vanish_zeta) }); // old_eval is the last derived evaluation; it will be checked for consistency with its @@ -440,7 +420,8 @@ impl PrecomputedReducedEvalsTarget { let g = builder.constant_extension(F::Extension::primitive_root_of_unity(degree_log)); let zeta_right = builder.mul_extension(g, zeta); - let (numerator, denominator) = builder.sub_two_extension(zs_right, zs, zeta_right, zeta); + let numerator = builder.sub_extension(zs_right, zs); + let denominator = builder.sub_extension(zeta_right, zeta); Self { single, diff --git a/src/gadgets/arithmetic.rs b/src/gadgets/arithmetic.rs index 2c41482e..52b1d660 100644 --- a/src/gadgets/arithmetic.rs +++ b/src/gadgets/arithmetic.rs @@ -86,16 +86,6 @@ impl, const D: usize> CircuitBuilder { self.arithmetic(F::ONE, x, y, F::ZERO, x) } - /// Computes `x * y`. - pub fn mul_two(&mut self, a0: Target, b0: Target, a1: Target, b1: Target) -> (Target, Target) { - let a0_ext = self.convert_to_ext(a0); - let b0_ext = self.convert_to_ext(b0); - let a1_ext = self.convert_to_ext(a1); - let b1_ext = self.convert_to_ext(b1); - let res = self.mul_two_extension(a0_ext, b0_ext, a1_ext, b1_ext); - (res.0 .0[0], res.1 .0[0]) - } - /// Multiply `n` `Target`s with `ceil(n/2) + 1` `ArithmeticExtensionGate`s. pub fn mul_many(&mut self, terms: &[Target]) -> Target { let terms_ext = terms diff --git a/src/gadgets/arithmetic_extension.rs b/src/gadgets/arithmetic_extension.rs index 127600a1..a27da240 100644 --- a/src/gadgets/arithmetic_extension.rs +++ b/src/gadgets/arithmetic_extension.rs @@ -6,7 +6,7 @@ use num::Integer; use crate::field::extension_field::target::{ExtensionAlgebraTarget, ExtensionTarget}; use crate::field::extension_field::{Extendable, OEF}; use crate::field::field_types::Field; -use crate::gates::arithmetic::ArithmeticExtensionGate; +use crate::gates::arithmetic::{ArithmeticExtensionGate, NUM_ARITHMETIC_OPS}; use crate::iop::generator::{GeneratedValues, SimpleGenerator}; use crate::iop::target::Target; use crate::iop::wire::Wire; @@ -15,111 +15,25 @@ use crate::plonk::circuit_builder::CircuitBuilder; use crate::util::bits_u64; impl, const D: usize> CircuitBuilder { - pub fn double_arithmetic_extension( - &mut self, - const_0: F, - const_1: F, - first_multiplicand_0: ExtensionTarget, - first_multiplicand_1: ExtensionTarget, - first_addend: ExtensionTarget, - second_multiplicand_0: ExtensionTarget, - second_multiplicand_1: ExtensionTarget, - second_addend: ExtensionTarget, - ) -> (ExtensionTarget, ExtensionTarget) { - if let Some((g, c_0, c_1)) = self.free_arithmetic { - if c_0 == const_0 && c_1 == const_1 { - return self.arithmetic_reusing_gate( - g, - first_multiplicand_0, - first_multiplicand_1, - first_addend, - second_multiplicand_0, - second_multiplicand_1, - second_addend, - ); - } + fn find_arithmetic_gate(&mut self, const_0: F, const_1: F) -> (usize, usize) { + let (gate, i) = self + .free_arithmetic + .get(&(const_0, const_1)) + .copied() + .unwrap_or_else(|| { + let gate = self.add_gate(ArithmeticExtensionGate, vec![const_0, const_1]); + (gate, 0) + }); + + // Update `free_arithmetic` with new values. + if i < NUM_ARITHMETIC_OPS - 1 { + self.free_arithmetic + .insert((const_0, const_1), (gate, i + 1)); + } else { + self.free_arithmetic.remove(&(const_0, const_1)); } - let gate = self.add_gate(ArithmeticExtensionGate, vec![const_0, const_1]); - self.free_arithmetic = Some((gate, const_0, const_1)); - let wire_first_multiplicand_0 = ExtensionTarget::from_range( - gate, - ArithmeticExtensionGate::::wires_first_multiplicand_0(), - ); - let wire_first_multiplicand_1 = ExtensionTarget::from_range( - gate, - ArithmeticExtensionGate::::wires_first_multiplicand_1(), - ); - let wire_first_addend = - ExtensionTarget::from_range(gate, ArithmeticExtensionGate::::wires_first_addend()); - let wire_second_multiplicand_0 = ExtensionTarget::from_range( - gate, - ArithmeticExtensionGate::::wires_second_multiplicand_0(), - ); - let wire_second_multiplicand_1 = ExtensionTarget::from_range( - gate, - ArithmeticExtensionGate::::wires_second_multiplicand_1(), - ); - let wire_second_addend = - ExtensionTarget::from_range(gate, ArithmeticExtensionGate::::wires_second_addend()); - let wire_first_output = - ExtensionTarget::from_range(gate, ArithmeticExtensionGate::::wires_first_output()); - let wire_second_output = - ExtensionTarget::from_range(gate, ArithmeticExtensionGate::::wires_second_output()); - - self.route_extension(first_multiplicand_0, wire_first_multiplicand_0); - self.route_extension(first_multiplicand_1, wire_first_multiplicand_1); - self.route_extension(first_addend, wire_first_addend); - self.route_extension(second_multiplicand_0, wire_second_multiplicand_0); - self.route_extension(second_multiplicand_1, wire_second_multiplicand_1); - self.route_extension(second_addend, wire_second_addend); - (wire_first_output, wire_second_output) - } - - fn arithmetic_reusing_gate( - &mut self, - gate: usize, - first_multiplicand_0: ExtensionTarget, - first_multiplicand_1: ExtensionTarget, - first_addend: ExtensionTarget, - second_multiplicand_0: ExtensionTarget, - second_multiplicand_1: ExtensionTarget, - second_addend: ExtensionTarget, - ) -> (ExtensionTarget, ExtensionTarget) { - let wire_third_multiplicand_0 = ExtensionTarget::from_range( - gate, - ArithmeticExtensionGate::::wires_third_multiplicand_0(), - ); - let wire_third_multiplicand_1 = ExtensionTarget::from_range( - gate, - ArithmeticExtensionGate::::wires_third_multiplicand_1(), - ); - let wire_third_addend = - ExtensionTarget::from_range(gate, ArithmeticExtensionGate::::wires_third_addend()); - let wire_fourth_multiplicand_0 = ExtensionTarget::from_range( - gate, - ArithmeticExtensionGate::::wires_fourth_multiplicand_0(), - ); - let wire_fourth_multiplicand_1 = ExtensionTarget::from_range( - gate, - ArithmeticExtensionGate::::wires_fourth_multiplicand_1(), - ); - let wire_fourth_addend = - ExtensionTarget::from_range(gate, ArithmeticExtensionGate::::wires_fourth_addend()); - let wire_third_output = - ExtensionTarget::from_range(gate, ArithmeticExtensionGate::::wires_third_output()); - let wire_fourth_output = - ExtensionTarget::from_range(gate, ArithmeticExtensionGate::::wires_fourth_output()); - - self.route_extension(first_multiplicand_0, wire_third_multiplicand_0); - self.route_extension(first_multiplicand_1, wire_third_multiplicand_1); - self.route_extension(first_addend, wire_third_addend); - self.route_extension(second_multiplicand_0, wire_fourth_multiplicand_0); - self.route_extension(second_multiplicand_1, wire_fourth_multiplicand_1); - self.route_extension(second_addend, wire_fourth_addend); - self.free_arithmetic = None; - - (wire_third_output, wire_fourth_output) + (gate, i) } pub fn arithmetic_extension( @@ -141,18 +55,23 @@ impl, const D: usize> CircuitBuilder { return result; } - let zero = self.zero_extension(); - self.double_arithmetic_extension( - const_0, - const_1, - multiplicand_0, - multiplicand_1, - addend, - zero, - zero, - zero, - ) - .0 + let (gate, i) = self.find_arithmetic_gate(const_0, const_1); + let wires_multiplicand_0 = ExtensionTarget::from_range( + gate, + ArithmeticExtensionGate::::wires_ith_multiplicand_0(i), + ); + let wires_multiplicand_1 = ExtensionTarget::from_range( + gate, + ArithmeticExtensionGate::::wires_ith_multiplicand_1(i), + ); + let wires_addend = + ExtensionTarget::from_range(gate, ArithmeticExtensionGate::::wires_ith_addend(i)); + + self.route_extension(multiplicand_0, wires_multiplicand_0); + self.route_extension(multiplicand_1, wires_multiplicand_1); + self.route_extension(addend, wires_addend); + + ExtensionTarget::from_range(gate, ArithmeticExtensionGate::::wires_ith_output(i)) } /// Checks for special cases where the value of @@ -233,37 +152,8 @@ impl, const D: usize> CircuitBuilder { pairs: Vec<(ExtensionTarget, ExtensionTarget)>, ) -> ExtensionTarget { let mut acc = starting_acc; - for chunk in pairs.chunks_exact(2) { - let (a0, b0) = chunk[0]; - let (a1, b1) = chunk[1]; - let (gate, range) = if let Some((g, c_0, c_1)) = self.free_arithmetic { - if c_0 == constant && c_1 == F::ONE { - (g, ArithmeticExtensionGate::::wires_third_output()) - } else { - ( - self.num_gates(), - ArithmeticExtensionGate::::wires_first_output(), - ) - } - } else { - ( - self.num_gates(), - ArithmeticExtensionGate::::wires_first_output(), - ) - }; - let first_out = ExtensionTarget::from_range(gate, range); - // let gate = self.num_gates(); - // let first_out = ExtensionTarget::from_range( - // gate, - // ArithmeticExtensionGate::::wires_first_output(), - // ); - acc = self - .double_arithmetic_extension(constant, F::ONE, a0, b0, acc, a1, b1, first_out) - .1; - } - if pairs.len().is_odd() { - let n = pairs.len() - 1; - acc = self.arithmetic_extension(constant, F::ONE, pairs[n].0, pairs[n].1, acc); + for (a, b) in pairs { + acc = self.arithmetic_extension(constant, F::ONE, a, b, acc); } acc } @@ -277,38 +167,15 @@ impl, const D: usize> CircuitBuilder { self.arithmetic_extension(F::ONE, F::ONE, one, a, b) } - /// Returns `(a0+b0, a1+b1)`. - pub fn add_two_extension( - &mut self, - a0: ExtensionTarget, - b0: ExtensionTarget, - a1: ExtensionTarget, - b1: ExtensionTarget, - ) -> (ExtensionTarget, ExtensionTarget) { - let one = self.one_extension(); - self.double_arithmetic_extension(F::ONE, F::ONE, one, a0, b0, one, a1, b1) - } - pub fn add_ext_algebra( &mut self, - a: ExtensionAlgebraTarget, + mut a: ExtensionAlgebraTarget, b: ExtensionAlgebraTarget, ) -> ExtensionAlgebraTarget { - // We run two additions in parallel. So `[a0,a1,a2,a3] + [b0,b1,b2,b3]` is computed with two - // `add_two_extension`, first `[a0,a1]+[b0,b1]` then `[a2,a3]+[b2,b3]`. - let mut res = Vec::with_capacity(D); - // We need some extra logic if D is odd. - let d_even = D & (D ^ 1); // = 2 * (D/2) - for mut chunk in &(0..d_even).chunks(2) { - let i = chunk.next().unwrap(); - let j = chunk.next().unwrap(); - let (o0, o1) = self.add_two_extension(a.0[i], b.0[i], a.0[j], b.0[j]); - res.extend([o0, o1]); + for i in 0..D { + a.0[i] = self.add_extension(a.0[i], b.0[i]); } - if D.is_odd() { - res.push(self.add_extension(a.0[D - 1], b.0[D - 1])); - } - ExtensionAlgebraTarget(res.try_into().unwrap()) + a } /// Add 3 `ExtensionTarget`s with 1 `ArithmeticExtensionGate`s. @@ -351,35 +218,15 @@ impl, const D: usize> CircuitBuilder { self.arithmetic_extension(F::ONE, F::NEG_ONE, one, a, b) } - pub fn sub_two_extension( - &mut self, - a0: ExtensionTarget, - b0: ExtensionTarget, - a1: ExtensionTarget, - b1: ExtensionTarget, - ) -> (ExtensionTarget, ExtensionTarget) { - let one = self.one_extension(); - self.double_arithmetic_extension(F::ONE, F::NEG_ONE, one, a0, b0, one, a1, b1) - } - pub fn sub_ext_algebra( &mut self, - a: ExtensionAlgebraTarget, + mut a: ExtensionAlgebraTarget, b: ExtensionAlgebraTarget, ) -> ExtensionAlgebraTarget { - // See `add_ext_algebra`. - let mut res = Vec::with_capacity(D); - let d_even = D & (D ^ 1); // = 2 * (D/2) - for mut chunk in &(0..d_even).chunks(2) { - let i = chunk.next().unwrap(); - let j = chunk.next().unwrap(); - let (o0, o1) = self.sub_two_extension(a.0[i], b.0[i], a.0[j], b.0[j]); - res.extend([o0, o1]); + for i in 0..D { + a.0[i] = self.sub_extension(a.0[i], b.0[i]); } - if D.is_odd() { - res.push(self.sub_extension(a.0[D - 1], b.0[D - 1])); - } - ExtensionAlgebraTarget(res.try_into().unwrap()) + a } pub fn mul_extension_with_const( @@ -389,17 +236,7 @@ impl, const D: usize> CircuitBuilder { multiplicand_1: ExtensionTarget, ) -> ExtensionTarget { let zero = self.zero_extension(); - self.double_arithmetic_extension( - const_0, - F::ZERO, - multiplicand_0, - multiplicand_1, - zero, - zero, - zero, - zero, - ) - .0 + self.arithmetic_extension(const_0, F::ZERO, multiplicand_0, multiplicand_1, zero) } pub fn mul_extension( @@ -410,18 +247,6 @@ impl, const D: usize> CircuitBuilder { self.mul_extension_with_const(F::ONE, multiplicand_0, multiplicand_1) } - /// Returns `(a0*b0, a1*b1)`. - pub fn mul_two_extension( - &mut self, - a0: ExtensionTarget, - b0: ExtensionTarget, - a1: ExtensionTarget, - b1: ExtensionTarget, - ) -> (ExtensionTarget, ExtensionTarget) { - let zero = self.zero_extension(); - self.double_arithmetic_extension(F::ONE, F::ZERO, a0, b0, zero, a1, b1, zero) - } - /// Computes `x^2`. pub fn square_extension(&mut self, x: ExtensionTarget) -> ExtensionTarget { self.mul_extension(x, x) @@ -479,25 +304,8 @@ impl, const D: usize> CircuitBuilder { b: ExtensionTarget, c: ExtensionTarget, ) -> ExtensionTarget { - let zero = self.zero_extension(); - let (gate, range) = if let Some((g, c_0, c_1)) = self.free_arithmetic { - if c_0 == F::ONE && c_1 == F::ONE { - (g, ArithmeticExtensionGate::::wires_third_output()) - } else { - ( - self.num_gates(), - ArithmeticExtensionGate::::wires_first_output(), - ) - } - } else { - ( - self.num_gates(), - ArithmeticExtensionGate::::wires_first_output(), - ) - }; - let first_out = ExtensionTarget::from_range(gate, range); - self.double_arithmetic_extension(F::ONE, F::ONE, a, b, zero, c, first_out, zero) - .1 + let tmp = self.mul_extension(a, b); + self.mul_extension(tmp, c) } /// Multiply `n` `ExtensionTarget`s with `n/2` `ArithmeticExtensionGate`s. @@ -574,22 +382,8 @@ impl, const D: usize> CircuitBuilder { b: ExtensionAlgebraTarget, mut c: ExtensionAlgebraTarget, ) -> ExtensionAlgebraTarget { - for i in 0..D / 2 { - let res = self.double_arithmetic_extension( - F::ONE, - F::ONE, - a, - b.0[2 * i], - c.0[2 * i], - a, - b.0[2 * i + 1], - c.0[2 * i + 1], - ); - c.0[2 * i] = res.0; - c.0[2 * i + 1] = res.1; - } - if D.is_odd() { - c.0[D - 1] = self.arithmetic_extension(F::ONE, F::ONE, a, b.0[D - 1], c.0[D - 1]); + for i in 0..D { + c.0[i] = self.mul_add_extension(a, b.0[i], c.0[i]); } c } @@ -670,11 +464,10 @@ impl, const D: usize> CircuitBuilder { }); // Enforce that x times its purported inverse equals 1. - let (y_inv, res) = - self.double_arithmetic_extension(F::ONE, F::ONE, y, inv, zero, x, inv, z); + let y_inv = self.mul_extension(y, inv); self.assert_equal_extension(y_inv, one); - res + self.mul_add_extension(x, inv, z) } /// Computes `1 / x`. Results in an unsatisfiable instance if `x = 0`. diff --git a/src/gadgets/interpolation.rs b/src/gadgets/interpolation.rs index 9d48a204..6db70b6e 100644 --- a/src/gadgets/interpolation.rs +++ b/src/gadgets/interpolation.rs @@ -5,29 +5,6 @@ use crate::iop::target::Target; use crate::plonk::circuit_builder::CircuitBuilder; impl, const D: usize> CircuitBuilder { - /// Interpolate two points. No need for an `InterpolationGate` since the coefficients - /// of the linear interpolation polynomial can be easily computed with arithmetic operations. - pub fn interpolate2( - &mut self, - interpolation_points: [(ExtensionTarget, ExtensionTarget); 2], - evaluation_point: ExtensionTarget, - ) -> ExtensionTarget { - // a0 -> a1 - // b0 -> b1 - // x -> a1 + (x-a0)*(b1-a1)/(b0-a0) - - let (x_m_a0, b1_m_a1) = self.sub_two_extension( - evaluation_point, - interpolation_points[0].0, - interpolation_points[1].1, - interpolation_points[0].1, - ); - let b0_m_a0 = self.sub_extension(interpolation_points[1].0, interpolation_points[0].0); - let quotient = self.div_extension(b1_m_a1, b0_m_a0); - - self.mul_add_extension(x_m_a0, quotient, interpolation_points[0].1) - } - /// Interpolate a list of point/evaluation pairs at a given point. /// Returns the evaluation of the interpolated polynomial at `evaluation_point`. pub fn interpolate( @@ -108,39 +85,4 @@ mod tests { verify(proof, &data.verifier_only, &data.common) } - - #[test] - fn test_interpolate2() -> Result<()> { - type F = CrandallField; - type FF = QuarticCrandallField; - let config = CircuitConfig::large_zk_config(); - let pw = PartialWitness::new(config.num_wires); - let mut builder = CircuitBuilder::::new(config); - - let len = 2; - let points = (0..len) - .map(|_| (FF::rand(), FF::rand())) - .collect::>(); - - let true_interpolant = interpolant(&points); - - let z = FF::rand(); - let true_eval = true_interpolant.eval(z); - - let points_target = points - .iter() - .map(|&(p, v)| (builder.constant_extension(p), builder.constant_extension(v))) - .collect::>(); - - let zt = builder.constant_extension(z); - - let eval = builder.interpolate2(points_target.try_into().unwrap(), zt); - let true_eval_target = builder.constant_extension(true_eval); - builder.assert_equal_extension(eval, true_eval_target); - - let data = builder.build(); - let proof = data.prove(pw)?; - - verify(proof, &data.verifier_only, &data.common) - } } diff --git a/src/gadgets/select.rs b/src/gadgets/select.rs index fbfdb6e8..8591a631 100644 --- a/src/gadgets/select.rs +++ b/src/gadgets/select.rs @@ -25,25 +25,8 @@ impl, const D: usize> CircuitBuilder { x: ExtensionTarget, y: ExtensionTarget, ) -> ExtensionTarget { - // Holds `by - y`. - let (gate, range) = if let Some((g, c_0, c_1)) = self.free_arithmetic { - if c_0 == F::ONE && c_1 == F::NEG_ONE { - (g, ArithmeticExtensionGate::::wires_third_output()) - } else { - ( - self.num_gates(), - ArithmeticExtensionGate::::wires_first_output(), - ) - } - } else { - ( - self.num_gates(), - ArithmeticExtensionGate::::wires_first_output(), - ) - }; - let first_out = ExtensionTarget::from_range(gate, range); - self.double_arithmetic_extension(F::ONE, F::NEG_ONE, b, y, y, b, x, first_out) - .1 + let tmp = self.mul_sub_extension(b, y, y); + self.mul_sub_extension(b, x, tmp) } /// See `select_ext`. diff --git a/src/gates/arithmetic.rs b/src/gates/arithmetic.rs index 7c780671..5b608519 100644 --- a/src/gates/arithmetic.rs +++ b/src/gates/arithmetic.rs @@ -10,61 +10,25 @@ use crate::iop::witness::PartialWitness; use crate::plonk::circuit_builder::CircuitBuilder; use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase}; +/// Number of arithmetic operations performed by an arithmetic gate. +pub const NUM_ARITHMETIC_OPS: usize = 4; + /// A gate which can a linear combination `c0*x*y+c1*z` twice with the same `x`. #[derive(Debug)] pub struct ArithmeticExtensionGate; impl ArithmeticExtensionGate { - pub fn wires_first_multiplicand_0() -> Range { - 0..D + pub fn wires_ith_multiplicand_0(i: usize) -> Range { + 4 * D * i..4 * D * i + D } - pub fn wires_first_multiplicand_1() -> Range { - D..2 * D + pub fn wires_ith_multiplicand_1(i: usize) -> Range { + 4 * D * i + D..4 * D * i + 2 * D } - pub fn wires_first_addend() -> Range { - 2 * D..3 * D + pub fn wires_ith_addend(i: usize) -> Range { + 4 * D * i + 2 * D..4 * D * i + 3 * D } - pub fn wires_first_output() -> Range { - 3 * D..4 * D - } - - pub fn wires_second_multiplicand_0() -> Range { - 4 * D..5 * D - } - pub fn wires_second_multiplicand_1() -> Range { - 5 * D..6 * D - } - pub fn wires_second_addend() -> Range { - 6 * D..7 * D - } - pub fn wires_second_output() -> Range { - 7 * D..8 * D - } - - pub fn wires_third_multiplicand_0() -> Range { - 8 * D..9 * D - } - pub fn wires_third_multiplicand_1() -> Range { - 9 * D..10 * D - } - pub fn wires_third_addend() -> Range { - 10 * D..11 * D - } - pub fn wires_third_output() -> Range { - 11 * D..12 * D - } - - pub fn wires_fourth_multiplicand_0() -> Range { - 12 * D..13 * D - } - pub fn wires_fourth_multiplicand_1() -> Range { - 13 * D..14 * D - } - pub fn wires_fourth_addend() -> Range { - 14 * D..15 * D - } - pub fn wires_fourth_output() -> Range { - 15 * D..16 * D + pub fn wires_ith_output(i: usize) -> Range { + 4 * D * i + 3 * D..4 * D * i + 4 * D } } @@ -77,38 +41,18 @@ impl, const D: usize> Gate for ArithmeticExtensionGate let const_0 = vars.local_constants[0]; let const_1 = vars.local_constants[1]; - let first_multiplicand_0 = vars.get_local_ext_algebra(Self::wires_first_multiplicand_0()); - let first_multiplicand_1 = vars.get_local_ext_algebra(Self::wires_first_multiplicand_1()); - let first_addend = vars.get_local_ext_algebra(Self::wires_first_addend()); - let second_multiplicand_0 = vars.get_local_ext_algebra(Self::wires_second_multiplicand_0()); - let second_multiplicand_1 = vars.get_local_ext_algebra(Self::wires_second_multiplicand_1()); - let second_addend = vars.get_local_ext_algebra(Self::wires_second_addend()); - let third_multiplicand_0 = vars.get_local_ext_algebra(Self::wires_third_multiplicand_0()); - let third_multiplicand_1 = vars.get_local_ext_algebra(Self::wires_third_multiplicand_1()); - let third_addend = vars.get_local_ext_algebra(Self::wires_third_addend()); - let fourth_multiplicand_0 = vars.get_local_ext_algebra(Self::wires_fourth_multiplicand_0()); - let fourth_multiplicand_1 = vars.get_local_ext_algebra(Self::wires_fourth_multiplicand_1()); - let fourth_addend = vars.get_local_ext_algebra(Self::wires_fourth_addend()); - let first_output = vars.get_local_ext_algebra(Self::wires_first_output()); - let second_output = vars.get_local_ext_algebra(Self::wires_second_output()); - let third_output = vars.get_local_ext_algebra(Self::wires_third_output()); - let fourth_output = vars.get_local_ext_algebra(Self::wires_fourth_output()); + let mut constraints = Vec::new(); + for i in 0..NUM_ARITHMETIC_OPS { + let multiplicand_0 = vars.get_local_ext_algebra(Self::wires_ith_multiplicand_0(i)); + let multiplicand_1 = vars.get_local_ext_algebra(Self::wires_ith_multiplicand_1(i)); + let addend = vars.get_local_ext_algebra(Self::wires_ith_addend(i)); + let output = vars.get_local_ext_algebra(Self::wires_ith_output(i)); + let computed_output = + multiplicand_0 * multiplicand_1 * const_0.into() + addend * const_1.into(); - let first_computed_output = first_multiplicand_0 * first_multiplicand_1 * const_0.into() - + first_addend * const_1.into(); - let second_computed_output = second_multiplicand_0 * second_multiplicand_1 * const_0.into() - + second_addend * const_1.into(); - let third_computed_output = third_multiplicand_0 * third_multiplicand_1 * const_0.into() - + third_addend * const_1.into(); - let fourth_computed_output = fourth_multiplicand_0 * fourth_multiplicand_1 * const_0.into() - + fourth_addend * const_1.into(); + constraints.extend((output - computed_output).to_basefield_array()); + } - let mut constraints = (first_output - first_computed_output) - .to_basefield_array() - .to_vec(); - constraints.extend((second_output - second_computed_output).to_basefield_array()); - constraints.extend((third_output - third_computed_output).to_basefield_array()); - constraints.extend((fourth_output - fourth_computed_output).to_basefield_array()); constraints } @@ -116,38 +60,18 @@ impl, const D: usize> Gate for ArithmeticExtensionGate let const_0 = vars.local_constants[0]; let const_1 = vars.local_constants[1]; - let first_multiplicand_0 = vars.get_local_ext(Self::wires_first_multiplicand_0()); - let first_multiplicand_1 = vars.get_local_ext(Self::wires_first_multiplicand_1()); - let first_addend = vars.get_local_ext(Self::wires_first_addend()); - let second_multiplicand_0 = vars.get_local_ext(Self::wires_second_multiplicand_0()); - let second_multiplicand_1 = vars.get_local_ext(Self::wires_second_multiplicand_1()); - let second_addend = vars.get_local_ext(Self::wires_second_addend()); - let third_multiplicand_0 = vars.get_local_ext(Self::wires_third_multiplicand_0()); - let third_multiplicand_1 = vars.get_local_ext(Self::wires_third_multiplicand_1()); - let third_addend = vars.get_local_ext(Self::wires_third_addend()); - let fourth_multiplicand_0 = vars.get_local_ext(Self::wires_fourth_multiplicand_0()); - let fourth_multiplicand_1 = vars.get_local_ext(Self::wires_fourth_multiplicand_1()); - let fourth_addend = vars.get_local_ext(Self::wires_fourth_addend()); - let first_output = vars.get_local_ext(Self::wires_first_output()); - let second_output = vars.get_local_ext(Self::wires_second_output()); - let third_output = vars.get_local_ext(Self::wires_third_output()); - let fourth_output = vars.get_local_ext(Self::wires_fourth_output()); + let mut constraints = Vec::new(); + for i in 0..NUM_ARITHMETIC_OPS { + let multiplicand_0 = vars.get_local_ext(Self::wires_ith_multiplicand_0(i)); + let multiplicand_1 = vars.get_local_ext(Self::wires_ith_multiplicand_1(i)); + let addend = vars.get_local_ext(Self::wires_ith_addend(i)); + let output = vars.get_local_ext(Self::wires_ith_output(i)); + let computed_output = + multiplicand_0 * multiplicand_1 * const_0.into() + addend * const_1.into(); - let first_computed_output = first_multiplicand_0 * first_multiplicand_1 * const_0.into() - + first_addend * const_1.into(); - let second_computed_output = second_multiplicand_0 * second_multiplicand_1 * const_0.into() - + second_addend * const_1.into(); - let third_computed_output = third_multiplicand_0 * third_multiplicand_1 * const_0.into() - + third_addend * const_1.into(); - let fourth_computed_output = fourth_multiplicand_0 * fourth_multiplicand_1 * const_0.into() - + fourth_addend * const_1.into(); + constraints.extend((output - computed_output).to_basefield_array()); + } - let mut constraints = (first_output - first_computed_output) - .to_basefield_array() - .to_vec(); - constraints.extend((second_output - second_computed_output).to_basefield_array()); - constraints.extend((third_output - third_computed_output).to_basefield_array()); - constraints.extend((fourth_output - fourth_computed_output).to_basefield_array()); constraints } @@ -159,61 +83,23 @@ impl, const D: usize> Gate for ArithmeticExtensionGate let const_0 = vars.local_constants[0]; let const_1 = vars.local_constants[1]; - let first_multiplicand_0 = vars.get_local_ext_algebra(Self::wires_first_multiplicand_0()); - let first_multiplicand_1 = vars.get_local_ext_algebra(Self::wires_first_multiplicand_1()); - let first_addend = vars.get_local_ext_algebra(Self::wires_first_addend()); - let second_multiplicand_0 = vars.get_local_ext_algebra(Self::wires_second_multiplicand_0()); - let second_multiplicand_1 = vars.get_local_ext_algebra(Self::wires_second_multiplicand_1()); - let second_addend = vars.get_local_ext_algebra(Self::wires_second_addend()); - let third_multiplicand_0 = vars.get_local_ext_algebra(Self::wires_third_multiplicand_0()); - let third_multiplicand_1 = vars.get_local_ext_algebra(Self::wires_third_multiplicand_1()); - let third_addend = vars.get_local_ext_algebra(Self::wires_third_addend()); - let fourth_multiplicand_0 = vars.get_local_ext_algebra(Self::wires_fourth_multiplicand_0()); - let fourth_multiplicand_1 = vars.get_local_ext_algebra(Self::wires_fourth_multiplicand_1()); - let fourth_addend = vars.get_local_ext_algebra(Self::wires_fourth_addend()); - let first_output = vars.get_local_ext_algebra(Self::wires_first_output()); - let second_output = vars.get_local_ext_algebra(Self::wires_second_output()); - let third_output = vars.get_local_ext_algebra(Self::wires_third_output()); - let fourth_output = vars.get_local_ext_algebra(Self::wires_fourth_output()); + let mut constraints = Vec::new(); + for i in 0..NUM_ARITHMETIC_OPS { + let multiplicand_0 = vars.get_local_ext_algebra(Self::wires_ith_multiplicand_0(i)); + let multiplicand_1 = vars.get_local_ext_algebra(Self::wires_ith_multiplicand_1(i)); + let addend = vars.get_local_ext_algebra(Self::wires_ith_addend(i)); + let output = vars.get_local_ext_algebra(Self::wires_ith_output(i)); + let computed_output = { + let mul = builder.mul_ext_algebra(multiplicand_0, multiplicand_1); + let scaled_mul = builder.scalar_mul_ext_algebra(const_0, mul); + let scaled_addend = builder.scalar_mul_ext_algebra(const_1, addend); + builder.add_ext_algebra(scaled_mul, scaled_addend) + }; - let first_computed_output = - builder.mul_ext_algebra(first_multiplicand_0, first_multiplicand_1); - let first_computed_output = builder.scalar_mul_ext_algebra(const_0, first_computed_output); - let first_scaled_addend = builder.scalar_mul_ext_algebra(const_1, first_addend); - let first_computed_output = - builder.add_ext_algebra(first_computed_output, first_scaled_addend); + let diff = builder.sub_ext_algebra(output, computed_output); + constraints.extend(diff.to_ext_target_array()); + } - let second_computed_output = - builder.mul_ext_algebra(second_multiplicand_0, second_multiplicand_1); - let second_computed_output = - builder.scalar_mul_ext_algebra(const_0, second_computed_output); - let second_scaled_addend = builder.scalar_mul_ext_algebra(const_1, second_addend); - let second_computed_output = - builder.add_ext_algebra(second_computed_output, second_scaled_addend); - - let third_computed_output = - builder.mul_ext_algebra(third_multiplicand_0, third_multiplicand_1); - let third_computed_output = builder.scalar_mul_ext_algebra(const_0, third_computed_output); - let third_scaled_addend = builder.scalar_mul_ext_algebra(const_1, third_addend); - let third_computed_output = - builder.add_ext_algebra(third_computed_output, third_scaled_addend); - - let fourth_computed_output = - builder.mul_ext_algebra(fourth_multiplicand_0, fourth_multiplicand_1); - let fourth_computed_output = - builder.scalar_mul_ext_algebra(const_0, fourth_computed_output); - let fourth_scaled_addend = builder.scalar_mul_ext_algebra(const_1, fourth_addend); - let fourth_computed_output = - builder.add_ext_algebra(fourth_computed_output, fourth_scaled_addend); - - let diff_0 = builder.sub_ext_algebra(first_output, first_computed_output); - let diff_1 = builder.sub_ext_algebra(second_output, second_computed_output); - let diff_2 = builder.sub_ext_algebra(third_output, third_computed_output); - let diff_3 = builder.sub_ext_algebra(fourth_output, fourth_computed_output); - let mut constraints = diff_0.to_ext_target_array().to_vec(); - constraints.extend(diff_1.to_ext_target_array()); - constraints.extend(diff_2.to_ext_target_array()); - constraints.extend(diff_3.to_ext_target_array()); constraints } @@ -222,24 +108,21 @@ impl, const D: usize> Gate for ArithmeticExtensionGate gate_index: usize, local_constants: &[F], ) -> Vec>> { - let gens = (0..4) - .map(|i| ArithmeticExtensionGenerator { - gate_index, - const_0: local_constants[0], - const_1: local_constants[1], - i, + (0..4) + .map(|i| { + let g: Box> = Box::new(ArithmeticExtensionGenerator { + gate_index, + const_0: local_constants[0], + const_1: local_constants[1], + i, + }); + g }) - .collect::>(); - vec![ - Box::new(gens[0].clone()), - Box::new(gens[1].clone()), - Box::new(gens[2].clone()), - Box::new(gens[3].clone()), - ] + .collect::>() } fn num_wires(&self) -> usize { - 16 * D + NUM_ARITHMETIC_OPS * 4 * D } fn num_constants(&self) -> usize { @@ -251,7 +134,7 @@ impl, const D: usize> Gate for ArithmeticExtensionGate } fn num_constraints(&self) -> usize { - 4 * D + NUM_ARITHMETIC_OPS * D } } @@ -265,7 +148,11 @@ struct ArithmeticExtensionGenerator, const D: usize> { impl, const D: usize> SimpleGenerator for ArithmeticExtensionGenerator { fn dependencies(&self) -> Vec { - (4 * self.i * D..(4 * self.i + 3) * D) + ArithmeticExtensionGate::::wires_ith_multiplicand_0(self.i) + .chain(ArithmeticExtensionGate::::wires_ith_multiplicand_1( + self.i, + )) + .chain(ArithmeticExtensionGate::::wires_ith_addend(self.i)) .map(|i| Target::wire(self.gate_index, i)) .collect() } @@ -276,13 +163,18 @@ impl, const D: usize> SimpleGenerator for ArithmeticExtensio witness.get_extension_target(t) }; - let start = 4 * self.i * D; - let multiplicand_0 = extract_extension(start..start + D); - let multiplicand_1 = extract_extension(start + D..start + 2 * D); - let addend = extract_extension(start + 2 * D..start + 3 * D); + let multiplicand_0 = extract_extension( + ArithmeticExtensionGate::::wires_ith_multiplicand_0(self.i), + ); + let multiplicand_1 = extract_extension( + ArithmeticExtensionGate::::wires_ith_multiplicand_1(self.i), + ); + let addend = extract_extension(ArithmeticExtensionGate::::wires_ith_addend(self.i)); - let output_target = - ExtensionTarget::from_range(self.gate_index, start + 3 * D..start + 4 * D); + let output_target = ExtensionTarget::from_range( + self.gate_index, + ArithmeticExtensionGate::::wires_ith_output(self.i), + ); let computed_output = multiplicand_0 * multiplicand_1 * self.const_0.into() + addend * self.const_1.into(); diff --git a/src/gates/gmimc.rs b/src/gates/gmimc.rs index 003c9ab0..070bb81e 100644 --- a/src/gates/gmimc.rs +++ b/src/gates/gmimc.rs @@ -197,18 +197,8 @@ impl, const D: usize, const R: usize> Gate for GMiMCGate< let cubing_input_wire = vars.local_wires[Self::wire_cubing_input(r)]; constraints.push(builder.sub_extension(cubing_input, cubing_input_wire)); let f = builder.cube_extension(cubing_input_wire); - // addition_buffer += f - // state[active] -= f - (addition_buffer, state[active]) = builder.double_arithmetic_extension( - F::ONE, - F::ONE, - one, - addition_buffer, - f, - neg_one, - f, - state[active], - ); + addition_buffer = builder.add_extension(addition_buffer, f); + state[active] = builder.sub_extension(state[active], f); } for i in 0..W { diff --git a/src/plonk/circuit_builder.rs b/src/plonk/circuit_builder.rs index 25f760c0..573e0532 100644 --- a/src/plonk/circuit_builder.rs +++ b/src/plonk/circuit_builder.rs @@ -61,7 +61,9 @@ pub struct CircuitBuilder, const D: usize> { constants_to_targets: HashMap, targets_to_constants: HashMap, - pub(crate) free_arithmetic: Option<(usize, F, F)>, + /// A map `(c0, c1) -> (g, i)` from constants `(c0,c1)` to an available arithmetic gate using + /// these constants with gate index `g` and already using `i` arithmetic operations. + pub(crate) free_arithmetic: HashMap<(F, F), (usize, usize)>, } impl, const D: usize> CircuitBuilder { @@ -78,7 +80,7 @@ impl, const D: usize> CircuitBuilder { generators: Vec::new(), constants_to_targets: HashMap::new(), targets_to_constants: HashMap::new(), - free_arithmetic: None, + free_arithmetic: HashMap::new(), } } diff --git a/src/plonk/recursive_verifier.rs b/src/plonk/recursive_verifier.rs index a551c809..413294fd 100644 --- a/src/plonk/recursive_verifier.rs +++ b/src/plonk/recursive_verifier.rs @@ -427,9 +427,9 @@ mod tests { zero_knowledge: false, cap_height: 3, fri_config: FriConfig { - proof_of_work_bits: 20, + proof_of_work_bits: 1, reduction_arity_bits: vec![3, 3, 3], - num_query_rounds: 27, + num_query_rounds: 40, }, }; let (proof_with_pis, vd, cd) = { diff --git a/src/plonk/vanishing_poly.rs b/src/plonk/vanishing_poly.rs index a2a97d4b..73acab42 100644 --- a/src/plonk/vanishing_poly.rs +++ b/src/plonk/vanishing_poly.rs @@ -306,17 +306,8 @@ pub(crate) fn eval_vanishing_poly_recursively, const D: usize>( // Holds `k[i] * x`. let mut s_ids = Vec::new(); - for j in 0..common_data.config.num_routed_wires / 2 { - let k_0 = builder.constant(common_data.k_is[2 * j]); - let k_0_ext = builder.convert_to_ext(k_0); - let k_1 = builder.constant(common_data.k_is[2 * j + 1]); - let k_1_ext = builder.convert_to_ext(k_1); - let tmp = builder.mul_two_extension(k_0_ext, x, k_1_ext, x); - s_ids.push(tmp.0); - s_ids.push(tmp.1); - } - if common_data.config.num_routed_wires.is_odd() { - let k = builder.constant(common_data.k_is[common_data.k_is.len() - 1]); + for j in 0..common_data.config.num_routed_wires { + let k = builder.constant(common_data.k_is[j]); let k_ext = builder.convert_to_ext(k); s_ids.push(builder.mul_extension(k_ext, x)); } diff --git a/src/util/reducing.rs b/src/util/reducing.rs index 459454db..863f3246 100644 --- a/src/util/reducing.rs +++ b/src/util/reducing.rs @@ -164,52 +164,15 @@ impl ReducingFactorTarget { where F: Extendable, { - let zero = builder.zero_extension(); let l = terms.len(); self.count += l as u64; let mut terms_vec = terms.to_vec(); - // If needed, we pad the original vector so that it has even length. - if terms_vec.len().is_odd() { - terms_vec.push(zero); - } + let mut acc = terms_vec.pop().unwrap(); terms_vec.reverse(); - let mut acc = zero; - for pair in terms_vec.chunks(2) { - // We will route the output of the first arithmetic operation to the multiplicand of the - // second, i.e. we compute the following: - // out_0 = alpha acc + pair[0] - // acc' = out_1 = alpha out_0 + pair[1] - - let (gate, range) = if let Some((g, c_0, c_1)) = builder.free_arithmetic { - if c_0 == F::ONE && c_1 == F::ONE { - (g, ArithmeticExtensionGate::::wires_third_output()) - } else { - ( - builder.num_gates(), - ArithmeticExtensionGate::::wires_first_output(), - ) - } - } else { - ( - builder.num_gates(), - ArithmeticExtensionGate::::wires_first_output(), - ) - }; - let out_0 = ExtensionTarget::from_range(gate, range); - acc = builder - .double_arithmetic_extension( - F::ONE, - F::ONE, - self.base, - acc, - pair[0], - self.base, - out_0, - pair[1], - ) - .1; + for x in terms_vec { + acc = builder.mul_add_extension(self.base, acc, x); } acc } @@ -227,21 +190,6 @@ impl ReducingFactorTarget { builder.mul_extension(exp, x) } - pub fn shift_and_mul( - &mut self, - x: ExtensionTarget, - a: ExtensionTarget, - b: ExtensionTarget, - builder: &mut CircuitBuilder, - ) -> (ExtensionTarget, ExtensionTarget) - where - F: Extendable, - { - let exp = builder.exp_u64_extension(self.base, self.count); - self.count = 0; - builder.mul_two_extension(exp, x, a, b) - } - pub fn reset(&mut self) { self.count = 0; }