The mother of all arithmetic optimizations

This commit is contained in:
wborgeaud 2021-08-16 10:18:10 +02:00
parent 6ba6201b94
commit b366482866
11 changed files with 157 additions and 645 deletions

View File

@ -41,23 +41,14 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
let coset_start = self.mul(start, x);
// The answer is gotten by interpolating {(x*g^i, P(x*g^i))} and evaluating at beta.
let g_powers = g
let points = g
.powers()
.take(arity)
.map(|y| self.constant(y))
.zip(evals)
.map(|(y, v)| {
let yc = self.constant(y);
(self.mul(coset_start, yc), v)
})
.collect::<Vec<_>>();
let mut coset = Vec::new();
for i in 0..arity / 2 {
let res = self.mul_two(
coset_start,
g_powers[2 * i],
coset_start,
g_powers[2 * i + 1],
);
coset.push(res.0);
coset.push(res.1);
}
let points = coset.into_iter().zip(evals).collect::<Vec<_>>();
self.interpolate(&points, beta)
}
@ -265,14 +256,11 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
precomputed_reduced_evals.slope,
precomputed_reduced_evals.zs,
);
let (zs_numerator, vanish_zeta_right) = self.sub_two_extension(
zs_composition_eval,
interpol_val,
subgroup_x,
precomputed_reduced_evals.zeta_right,
);
let (mut sum, zs_denominator) =
alpha.shift_and_mul(sum, vanish_zeta, vanish_zeta_right, self);
let zs_numerator = self.sub_extension(zs_composition_eval, interpol_val);
let vanish_zeta_right =
self.sub_extension(subgroup_x, precomputed_reduced_evals.zeta_right);
sum = alpha.shift(sum, self);
let zs_denominator = self.mul_extension(vanish_zeta, vanish_zeta_right);
sum = self.div_add_extension(zs_numerator, zs_denominator, sum);
sum
@ -319,17 +307,9 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
let phi_ext = self.convert_to_ext(phi);
let zero = self.zero_extension();
// `subgroup_x = g*phi, vanish_zeta = g*phi - zeta`
let tmp = self.double_arithmetic_extension(
F::ONE,
F::NEG_ONE,
g_ext,
phi_ext,
zero,
g_ext,
phi_ext,
zeta,
);
(tmp.0 .0[0], tmp.1)
let subgroup_x = self.mul(g, phi);
let vanish_zeta = self.mul_sub_extension(g_ext, phi_ext, zeta);
(subgroup_x, vanish_zeta)
});
// old_eval is the last derived evaluation; it will be checked for consistency with its
@ -440,7 +420,8 @@ impl<const D: usize> PrecomputedReducedEvalsTarget<D> {
let g = builder.constant_extension(F::Extension::primitive_root_of_unity(degree_log));
let zeta_right = builder.mul_extension(g, zeta);
let (numerator, denominator) = builder.sub_two_extension(zs_right, zs, zeta_right, zeta);
let numerator = builder.sub_extension(zs_right, zs);
let denominator = builder.sub_extension(zeta_right, zeta);
Self {
single,

View File

@ -86,16 +86,6 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
self.arithmetic(F::ONE, x, y, F::ZERO, x)
}
/// Computes `x * y`.
pub fn mul_two(&mut self, a0: Target, b0: Target, a1: Target, b1: Target) -> (Target, Target) {
let a0_ext = self.convert_to_ext(a0);
let b0_ext = self.convert_to_ext(b0);
let a1_ext = self.convert_to_ext(a1);
let b1_ext = self.convert_to_ext(b1);
let res = self.mul_two_extension(a0_ext, b0_ext, a1_ext, b1_ext);
(res.0 .0[0], res.1 .0[0])
}
/// Multiply `n` `Target`s with `ceil(n/2) + 1` `ArithmeticExtensionGate`s.
pub fn mul_many(&mut self, terms: &[Target]) -> Target {
let terms_ext = terms

View File

@ -6,7 +6,7 @@ use num::Integer;
use crate::field::extension_field::target::{ExtensionAlgebraTarget, ExtensionTarget};
use crate::field::extension_field::{Extendable, OEF};
use crate::field::field_types::Field;
use crate::gates::arithmetic::ArithmeticExtensionGate;
use crate::gates::arithmetic::{ArithmeticExtensionGate, NUM_ARITHMETIC_OPS};
use crate::iop::generator::{GeneratedValues, SimpleGenerator};
use crate::iop::target::Target;
use crate::iop::wire::Wire;
@ -15,111 +15,25 @@ use crate::plonk::circuit_builder::CircuitBuilder;
use crate::util::bits_u64;
impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
pub fn double_arithmetic_extension(
&mut self,
const_0: F,
const_1: F,
first_multiplicand_0: ExtensionTarget<D>,
first_multiplicand_1: ExtensionTarget<D>,
first_addend: ExtensionTarget<D>,
second_multiplicand_0: ExtensionTarget<D>,
second_multiplicand_1: ExtensionTarget<D>,
second_addend: ExtensionTarget<D>,
) -> (ExtensionTarget<D>, ExtensionTarget<D>) {
if let Some((g, c_0, c_1)) = self.free_arithmetic {
if c_0 == const_0 && c_1 == const_1 {
return self.arithmetic_reusing_gate(
g,
first_multiplicand_0,
first_multiplicand_1,
first_addend,
second_multiplicand_0,
second_multiplicand_1,
second_addend,
);
}
fn find_arithmetic_gate(&mut self, const_0: F, const_1: F) -> (usize, usize) {
let (gate, i) = self
.free_arithmetic
.get(&(const_0, const_1))
.copied()
.unwrap_or_else(|| {
let gate = self.add_gate(ArithmeticExtensionGate, vec![const_0, const_1]);
(gate, 0)
});
// Update `free_arithmetic` with new values.
if i < NUM_ARITHMETIC_OPS - 1 {
self.free_arithmetic
.insert((const_0, const_1), (gate, i + 1));
} else {
self.free_arithmetic.remove(&(const_0, const_1));
}
let gate = self.add_gate(ArithmeticExtensionGate, vec![const_0, const_1]);
self.free_arithmetic = Some((gate, const_0, const_1));
let wire_first_multiplicand_0 = ExtensionTarget::from_range(
gate,
ArithmeticExtensionGate::<D>::wires_first_multiplicand_0(),
);
let wire_first_multiplicand_1 = ExtensionTarget::from_range(
gate,
ArithmeticExtensionGate::<D>::wires_first_multiplicand_1(),
);
let wire_first_addend =
ExtensionTarget::from_range(gate, ArithmeticExtensionGate::<D>::wires_first_addend());
let wire_second_multiplicand_0 = ExtensionTarget::from_range(
gate,
ArithmeticExtensionGate::<D>::wires_second_multiplicand_0(),
);
let wire_second_multiplicand_1 = ExtensionTarget::from_range(
gate,
ArithmeticExtensionGate::<D>::wires_second_multiplicand_1(),
);
let wire_second_addend =
ExtensionTarget::from_range(gate, ArithmeticExtensionGate::<D>::wires_second_addend());
let wire_first_output =
ExtensionTarget::from_range(gate, ArithmeticExtensionGate::<D>::wires_first_output());
let wire_second_output =
ExtensionTarget::from_range(gate, ArithmeticExtensionGate::<D>::wires_second_output());
self.route_extension(first_multiplicand_0, wire_first_multiplicand_0);
self.route_extension(first_multiplicand_1, wire_first_multiplicand_1);
self.route_extension(first_addend, wire_first_addend);
self.route_extension(second_multiplicand_0, wire_second_multiplicand_0);
self.route_extension(second_multiplicand_1, wire_second_multiplicand_1);
self.route_extension(second_addend, wire_second_addend);
(wire_first_output, wire_second_output)
}
fn arithmetic_reusing_gate(
&mut self,
gate: usize,
first_multiplicand_0: ExtensionTarget<D>,
first_multiplicand_1: ExtensionTarget<D>,
first_addend: ExtensionTarget<D>,
second_multiplicand_0: ExtensionTarget<D>,
second_multiplicand_1: ExtensionTarget<D>,
second_addend: ExtensionTarget<D>,
) -> (ExtensionTarget<D>, ExtensionTarget<D>) {
let wire_third_multiplicand_0 = ExtensionTarget::from_range(
gate,
ArithmeticExtensionGate::<D>::wires_third_multiplicand_0(),
);
let wire_third_multiplicand_1 = ExtensionTarget::from_range(
gate,
ArithmeticExtensionGate::<D>::wires_third_multiplicand_1(),
);
let wire_third_addend =
ExtensionTarget::from_range(gate, ArithmeticExtensionGate::<D>::wires_third_addend());
let wire_fourth_multiplicand_0 = ExtensionTarget::from_range(
gate,
ArithmeticExtensionGate::<D>::wires_fourth_multiplicand_0(),
);
let wire_fourth_multiplicand_1 = ExtensionTarget::from_range(
gate,
ArithmeticExtensionGate::<D>::wires_fourth_multiplicand_1(),
);
let wire_fourth_addend =
ExtensionTarget::from_range(gate, ArithmeticExtensionGate::<D>::wires_fourth_addend());
let wire_third_output =
ExtensionTarget::from_range(gate, ArithmeticExtensionGate::<D>::wires_third_output());
let wire_fourth_output =
ExtensionTarget::from_range(gate, ArithmeticExtensionGate::<D>::wires_fourth_output());
self.route_extension(first_multiplicand_0, wire_third_multiplicand_0);
self.route_extension(first_multiplicand_1, wire_third_multiplicand_1);
self.route_extension(first_addend, wire_third_addend);
self.route_extension(second_multiplicand_0, wire_fourth_multiplicand_0);
self.route_extension(second_multiplicand_1, wire_fourth_multiplicand_1);
self.route_extension(second_addend, wire_fourth_addend);
self.free_arithmetic = None;
(wire_third_output, wire_fourth_output)
(gate, i)
}
pub fn arithmetic_extension(
@ -141,18 +55,23 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
return result;
}
let zero = self.zero_extension();
self.double_arithmetic_extension(
const_0,
const_1,
multiplicand_0,
multiplicand_1,
addend,
zero,
zero,
zero,
)
.0
let (gate, i) = self.find_arithmetic_gate(const_0, const_1);
let wires_multiplicand_0 = ExtensionTarget::from_range(
gate,
ArithmeticExtensionGate::<D>::wires_ith_multiplicand_0(i),
);
let wires_multiplicand_1 = ExtensionTarget::from_range(
gate,
ArithmeticExtensionGate::<D>::wires_ith_multiplicand_1(i),
);
let wires_addend =
ExtensionTarget::from_range(gate, ArithmeticExtensionGate::<D>::wires_ith_addend(i));
self.route_extension(multiplicand_0, wires_multiplicand_0);
self.route_extension(multiplicand_1, wires_multiplicand_1);
self.route_extension(addend, wires_addend);
ExtensionTarget::from_range(gate, ArithmeticExtensionGate::<D>::wires_ith_output(i))
}
/// Checks for special cases where the value of
@ -233,37 +152,8 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
pairs: Vec<(ExtensionTarget<D>, ExtensionTarget<D>)>,
) -> ExtensionTarget<D> {
let mut acc = starting_acc;
for chunk in pairs.chunks_exact(2) {
let (a0, b0) = chunk[0];
let (a1, b1) = chunk[1];
let (gate, range) = if let Some((g, c_0, c_1)) = self.free_arithmetic {
if c_0 == constant && c_1 == F::ONE {
(g, ArithmeticExtensionGate::<D>::wires_third_output())
} else {
(
self.num_gates(),
ArithmeticExtensionGate::<D>::wires_first_output(),
)
}
} else {
(
self.num_gates(),
ArithmeticExtensionGate::<D>::wires_first_output(),
)
};
let first_out = ExtensionTarget::from_range(gate, range);
// let gate = self.num_gates();
// let first_out = ExtensionTarget::from_range(
// gate,
// ArithmeticExtensionGate::<D>::wires_first_output(),
// );
acc = self
.double_arithmetic_extension(constant, F::ONE, a0, b0, acc, a1, b1, first_out)
.1;
}
if pairs.len().is_odd() {
let n = pairs.len() - 1;
acc = self.arithmetic_extension(constant, F::ONE, pairs[n].0, pairs[n].1, acc);
for (a, b) in pairs {
acc = self.arithmetic_extension(constant, F::ONE, a, b, acc);
}
acc
}
@ -277,38 +167,15 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
self.arithmetic_extension(F::ONE, F::ONE, one, a, b)
}
/// Returns `(a0+b0, a1+b1)`.
pub fn add_two_extension(
&mut self,
a0: ExtensionTarget<D>,
b0: ExtensionTarget<D>,
a1: ExtensionTarget<D>,
b1: ExtensionTarget<D>,
) -> (ExtensionTarget<D>, ExtensionTarget<D>) {
let one = self.one_extension();
self.double_arithmetic_extension(F::ONE, F::ONE, one, a0, b0, one, a1, b1)
}
pub fn add_ext_algebra(
&mut self,
a: ExtensionAlgebraTarget<D>,
mut a: ExtensionAlgebraTarget<D>,
b: ExtensionAlgebraTarget<D>,
) -> ExtensionAlgebraTarget<D> {
// We run two additions in parallel. So `[a0,a1,a2,a3] + [b0,b1,b2,b3]` is computed with two
// `add_two_extension`, first `[a0,a1]+[b0,b1]` then `[a2,a3]+[b2,b3]`.
let mut res = Vec::with_capacity(D);
// We need some extra logic if D is odd.
let d_even = D & (D ^ 1); // = 2 * (D/2)
for mut chunk in &(0..d_even).chunks(2) {
let i = chunk.next().unwrap();
let j = chunk.next().unwrap();
let (o0, o1) = self.add_two_extension(a.0[i], b.0[i], a.0[j], b.0[j]);
res.extend([o0, o1]);
for i in 0..D {
a.0[i] = self.add_extension(a.0[i], b.0[i]);
}
if D.is_odd() {
res.push(self.add_extension(a.0[D - 1], b.0[D - 1]));
}
ExtensionAlgebraTarget(res.try_into().unwrap())
a
}
/// Add 3 `ExtensionTarget`s with 1 `ArithmeticExtensionGate`s.
@ -351,35 +218,15 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
self.arithmetic_extension(F::ONE, F::NEG_ONE, one, a, b)
}
pub fn sub_two_extension(
&mut self,
a0: ExtensionTarget<D>,
b0: ExtensionTarget<D>,
a1: ExtensionTarget<D>,
b1: ExtensionTarget<D>,
) -> (ExtensionTarget<D>, ExtensionTarget<D>) {
let one = self.one_extension();
self.double_arithmetic_extension(F::ONE, F::NEG_ONE, one, a0, b0, one, a1, b1)
}
pub fn sub_ext_algebra(
&mut self,
a: ExtensionAlgebraTarget<D>,
mut a: ExtensionAlgebraTarget<D>,
b: ExtensionAlgebraTarget<D>,
) -> ExtensionAlgebraTarget<D> {
// See `add_ext_algebra`.
let mut res = Vec::with_capacity(D);
let d_even = D & (D ^ 1); // = 2 * (D/2)
for mut chunk in &(0..d_even).chunks(2) {
let i = chunk.next().unwrap();
let j = chunk.next().unwrap();
let (o0, o1) = self.sub_two_extension(a.0[i], b.0[i], a.0[j], b.0[j]);
res.extend([o0, o1]);
for i in 0..D {
a.0[i] = self.sub_extension(a.0[i], b.0[i]);
}
if D.is_odd() {
res.push(self.sub_extension(a.0[D - 1], b.0[D - 1]));
}
ExtensionAlgebraTarget(res.try_into().unwrap())
a
}
pub fn mul_extension_with_const(
@ -389,17 +236,7 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
multiplicand_1: ExtensionTarget<D>,
) -> ExtensionTarget<D> {
let zero = self.zero_extension();
self.double_arithmetic_extension(
const_0,
F::ZERO,
multiplicand_0,
multiplicand_1,
zero,
zero,
zero,
zero,
)
.0
self.arithmetic_extension(const_0, F::ZERO, multiplicand_0, multiplicand_1, zero)
}
pub fn mul_extension(
@ -410,18 +247,6 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
self.mul_extension_with_const(F::ONE, multiplicand_0, multiplicand_1)
}
/// Returns `(a0*b0, a1*b1)`.
pub fn mul_two_extension(
&mut self,
a0: ExtensionTarget<D>,
b0: ExtensionTarget<D>,
a1: ExtensionTarget<D>,
b1: ExtensionTarget<D>,
) -> (ExtensionTarget<D>, ExtensionTarget<D>) {
let zero = self.zero_extension();
self.double_arithmetic_extension(F::ONE, F::ZERO, a0, b0, zero, a1, b1, zero)
}
/// Computes `x^2`.
pub fn square_extension(&mut self, x: ExtensionTarget<D>) -> ExtensionTarget<D> {
self.mul_extension(x, x)
@ -479,25 +304,8 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
b: ExtensionTarget<D>,
c: ExtensionTarget<D>,
) -> ExtensionTarget<D> {
let zero = self.zero_extension();
let (gate, range) = if let Some((g, c_0, c_1)) = self.free_arithmetic {
if c_0 == F::ONE && c_1 == F::ONE {
(g, ArithmeticExtensionGate::<D>::wires_third_output())
} else {
(
self.num_gates(),
ArithmeticExtensionGate::<D>::wires_first_output(),
)
}
} else {
(
self.num_gates(),
ArithmeticExtensionGate::<D>::wires_first_output(),
)
};
let first_out = ExtensionTarget::from_range(gate, range);
self.double_arithmetic_extension(F::ONE, F::ONE, a, b, zero, c, first_out, zero)
.1
let tmp = self.mul_extension(a, b);
self.mul_extension(tmp, c)
}
/// Multiply `n` `ExtensionTarget`s with `n/2` `ArithmeticExtensionGate`s.
@ -574,22 +382,8 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
b: ExtensionAlgebraTarget<D>,
mut c: ExtensionAlgebraTarget<D>,
) -> ExtensionAlgebraTarget<D> {
for i in 0..D / 2 {
let res = self.double_arithmetic_extension(
F::ONE,
F::ONE,
a,
b.0[2 * i],
c.0[2 * i],
a,
b.0[2 * i + 1],
c.0[2 * i + 1],
);
c.0[2 * i] = res.0;
c.0[2 * i + 1] = res.1;
}
if D.is_odd() {
c.0[D - 1] = self.arithmetic_extension(F::ONE, F::ONE, a, b.0[D - 1], c.0[D - 1]);
for i in 0..D {
c.0[i] = self.mul_add_extension(a, b.0[i], c.0[i]);
}
c
}
@ -670,11 +464,10 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
});
// Enforce that x times its purported inverse equals 1.
let (y_inv, res) =
self.double_arithmetic_extension(F::ONE, F::ONE, y, inv, zero, x, inv, z);
let y_inv = self.mul_extension(y, inv);
self.assert_equal_extension(y_inv, one);
res
self.mul_add_extension(x, inv, z)
}
/// Computes `1 / x`. Results in an unsatisfiable instance if `x = 0`.

View File

@ -5,29 +5,6 @@ use crate::iop::target::Target;
use crate::plonk::circuit_builder::CircuitBuilder;
impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
/// Interpolate two points. No need for an `InterpolationGate` since the coefficients
/// of the linear interpolation polynomial can be easily computed with arithmetic operations.
pub fn interpolate2(
&mut self,
interpolation_points: [(ExtensionTarget<D>, ExtensionTarget<D>); 2],
evaluation_point: ExtensionTarget<D>,
) -> ExtensionTarget<D> {
// a0 -> a1
// b0 -> b1
// x -> a1 + (x-a0)*(b1-a1)/(b0-a0)
let (x_m_a0, b1_m_a1) = self.sub_two_extension(
evaluation_point,
interpolation_points[0].0,
interpolation_points[1].1,
interpolation_points[0].1,
);
let b0_m_a0 = self.sub_extension(interpolation_points[1].0, interpolation_points[0].0);
let quotient = self.div_extension(b1_m_a1, b0_m_a0);
self.mul_add_extension(x_m_a0, quotient, interpolation_points[0].1)
}
/// Interpolate a list of point/evaluation pairs at a given point.
/// Returns the evaluation of the interpolated polynomial at `evaluation_point`.
pub fn interpolate(
@ -108,39 +85,4 @@ mod tests {
verify(proof, &data.verifier_only, &data.common)
}
#[test]
fn test_interpolate2() -> Result<()> {
type F = CrandallField;
type FF = QuarticCrandallField;
let config = CircuitConfig::large_zk_config();
let pw = PartialWitness::new(config.num_wires);
let mut builder = CircuitBuilder::<F, 4>::new(config);
let len = 2;
let points = (0..len)
.map(|_| (FF::rand(), FF::rand()))
.collect::<Vec<_>>();
let true_interpolant = interpolant(&points);
let z = FF::rand();
let true_eval = true_interpolant.eval(z);
let points_target = points
.iter()
.map(|&(p, v)| (builder.constant_extension(p), builder.constant_extension(v)))
.collect::<Vec<_>>();
let zt = builder.constant_extension(z);
let eval = builder.interpolate2(points_target.try_into().unwrap(), zt);
let true_eval_target = builder.constant_extension(true_eval);
builder.assert_equal_extension(eval, true_eval_target);
let data = builder.build();
let proof = data.prove(pw)?;
verify(proof, &data.verifier_only, &data.common)
}
}

View File

@ -25,25 +25,8 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
x: ExtensionTarget<D>,
y: ExtensionTarget<D>,
) -> ExtensionTarget<D> {
// Holds `by - y`.
let (gate, range) = if let Some((g, c_0, c_1)) = self.free_arithmetic {
if c_0 == F::ONE && c_1 == F::NEG_ONE {
(g, ArithmeticExtensionGate::<D>::wires_third_output())
} else {
(
self.num_gates(),
ArithmeticExtensionGate::<D>::wires_first_output(),
)
}
} else {
(
self.num_gates(),
ArithmeticExtensionGate::<D>::wires_first_output(),
)
};
let first_out = ExtensionTarget::from_range(gate, range);
self.double_arithmetic_extension(F::ONE, F::NEG_ONE, b, y, y, b, x, first_out)
.1
let tmp = self.mul_sub_extension(b, y, y);
self.mul_sub_extension(b, x, tmp)
}
/// See `select_ext`.

View File

@ -10,61 +10,25 @@ use crate::iop::witness::PartialWitness;
use crate::plonk::circuit_builder::CircuitBuilder;
use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase};
/// Number of arithmetic operations performed by an arithmetic gate.
pub const NUM_ARITHMETIC_OPS: usize = 4;
/// A gate which can a linear combination `c0*x*y+c1*z` twice with the same `x`.
#[derive(Debug)]
pub struct ArithmeticExtensionGate<const D: usize>;
impl<const D: usize> ArithmeticExtensionGate<D> {
pub fn wires_first_multiplicand_0() -> Range<usize> {
0..D
pub fn wires_ith_multiplicand_0(i: usize) -> Range<usize> {
4 * D * i..4 * D * i + D
}
pub fn wires_first_multiplicand_1() -> Range<usize> {
D..2 * D
pub fn wires_ith_multiplicand_1(i: usize) -> Range<usize> {
4 * D * i + D..4 * D * i + 2 * D
}
pub fn wires_first_addend() -> Range<usize> {
2 * D..3 * D
pub fn wires_ith_addend(i: usize) -> Range<usize> {
4 * D * i + 2 * D..4 * D * i + 3 * D
}
pub fn wires_first_output() -> Range<usize> {
3 * D..4 * D
}
pub fn wires_second_multiplicand_0() -> Range<usize> {
4 * D..5 * D
}
pub fn wires_second_multiplicand_1() -> Range<usize> {
5 * D..6 * D
}
pub fn wires_second_addend() -> Range<usize> {
6 * D..7 * D
}
pub fn wires_second_output() -> Range<usize> {
7 * D..8 * D
}
pub fn wires_third_multiplicand_0() -> Range<usize> {
8 * D..9 * D
}
pub fn wires_third_multiplicand_1() -> Range<usize> {
9 * D..10 * D
}
pub fn wires_third_addend() -> Range<usize> {
10 * D..11 * D
}
pub fn wires_third_output() -> Range<usize> {
11 * D..12 * D
}
pub fn wires_fourth_multiplicand_0() -> Range<usize> {
12 * D..13 * D
}
pub fn wires_fourth_multiplicand_1() -> Range<usize> {
13 * D..14 * D
}
pub fn wires_fourth_addend() -> Range<usize> {
14 * D..15 * D
}
pub fn wires_fourth_output() -> Range<usize> {
15 * D..16 * D
pub fn wires_ith_output(i: usize) -> Range<usize> {
4 * D * i + 3 * D..4 * D * i + 4 * D
}
}
@ -77,38 +41,18 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for ArithmeticExtensionGate<D>
let const_0 = vars.local_constants[0];
let const_1 = vars.local_constants[1];
let first_multiplicand_0 = vars.get_local_ext_algebra(Self::wires_first_multiplicand_0());
let first_multiplicand_1 = vars.get_local_ext_algebra(Self::wires_first_multiplicand_1());
let first_addend = vars.get_local_ext_algebra(Self::wires_first_addend());
let second_multiplicand_0 = vars.get_local_ext_algebra(Self::wires_second_multiplicand_0());
let second_multiplicand_1 = vars.get_local_ext_algebra(Self::wires_second_multiplicand_1());
let second_addend = vars.get_local_ext_algebra(Self::wires_second_addend());
let third_multiplicand_0 = vars.get_local_ext_algebra(Self::wires_third_multiplicand_0());
let third_multiplicand_1 = vars.get_local_ext_algebra(Self::wires_third_multiplicand_1());
let third_addend = vars.get_local_ext_algebra(Self::wires_third_addend());
let fourth_multiplicand_0 = vars.get_local_ext_algebra(Self::wires_fourth_multiplicand_0());
let fourth_multiplicand_1 = vars.get_local_ext_algebra(Self::wires_fourth_multiplicand_1());
let fourth_addend = vars.get_local_ext_algebra(Self::wires_fourth_addend());
let first_output = vars.get_local_ext_algebra(Self::wires_first_output());
let second_output = vars.get_local_ext_algebra(Self::wires_second_output());
let third_output = vars.get_local_ext_algebra(Self::wires_third_output());
let fourth_output = vars.get_local_ext_algebra(Self::wires_fourth_output());
let mut constraints = Vec::new();
for i in 0..NUM_ARITHMETIC_OPS {
let multiplicand_0 = vars.get_local_ext_algebra(Self::wires_ith_multiplicand_0(i));
let multiplicand_1 = vars.get_local_ext_algebra(Self::wires_ith_multiplicand_1(i));
let addend = vars.get_local_ext_algebra(Self::wires_ith_addend(i));
let output = vars.get_local_ext_algebra(Self::wires_ith_output(i));
let computed_output =
multiplicand_0 * multiplicand_1 * const_0.into() + addend * const_1.into();
let first_computed_output = first_multiplicand_0 * first_multiplicand_1 * const_0.into()
+ first_addend * const_1.into();
let second_computed_output = second_multiplicand_0 * second_multiplicand_1 * const_0.into()
+ second_addend * const_1.into();
let third_computed_output = third_multiplicand_0 * third_multiplicand_1 * const_0.into()
+ third_addend * const_1.into();
let fourth_computed_output = fourth_multiplicand_0 * fourth_multiplicand_1 * const_0.into()
+ fourth_addend * const_1.into();
constraints.extend((output - computed_output).to_basefield_array());
}
let mut constraints = (first_output - first_computed_output)
.to_basefield_array()
.to_vec();
constraints.extend((second_output - second_computed_output).to_basefield_array());
constraints.extend((third_output - third_computed_output).to_basefield_array());
constraints.extend((fourth_output - fourth_computed_output).to_basefield_array());
constraints
}
@ -116,38 +60,18 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for ArithmeticExtensionGate<D>
let const_0 = vars.local_constants[0];
let const_1 = vars.local_constants[1];
let first_multiplicand_0 = vars.get_local_ext(Self::wires_first_multiplicand_0());
let first_multiplicand_1 = vars.get_local_ext(Self::wires_first_multiplicand_1());
let first_addend = vars.get_local_ext(Self::wires_first_addend());
let second_multiplicand_0 = vars.get_local_ext(Self::wires_second_multiplicand_0());
let second_multiplicand_1 = vars.get_local_ext(Self::wires_second_multiplicand_1());
let second_addend = vars.get_local_ext(Self::wires_second_addend());
let third_multiplicand_0 = vars.get_local_ext(Self::wires_third_multiplicand_0());
let third_multiplicand_1 = vars.get_local_ext(Self::wires_third_multiplicand_1());
let third_addend = vars.get_local_ext(Self::wires_third_addend());
let fourth_multiplicand_0 = vars.get_local_ext(Self::wires_fourth_multiplicand_0());
let fourth_multiplicand_1 = vars.get_local_ext(Self::wires_fourth_multiplicand_1());
let fourth_addend = vars.get_local_ext(Self::wires_fourth_addend());
let first_output = vars.get_local_ext(Self::wires_first_output());
let second_output = vars.get_local_ext(Self::wires_second_output());
let third_output = vars.get_local_ext(Self::wires_third_output());
let fourth_output = vars.get_local_ext(Self::wires_fourth_output());
let mut constraints = Vec::new();
for i in 0..NUM_ARITHMETIC_OPS {
let multiplicand_0 = vars.get_local_ext(Self::wires_ith_multiplicand_0(i));
let multiplicand_1 = vars.get_local_ext(Self::wires_ith_multiplicand_1(i));
let addend = vars.get_local_ext(Self::wires_ith_addend(i));
let output = vars.get_local_ext(Self::wires_ith_output(i));
let computed_output =
multiplicand_0 * multiplicand_1 * const_0.into() + addend * const_1.into();
let first_computed_output = first_multiplicand_0 * first_multiplicand_1 * const_0.into()
+ first_addend * const_1.into();
let second_computed_output = second_multiplicand_0 * second_multiplicand_1 * const_0.into()
+ second_addend * const_1.into();
let third_computed_output = third_multiplicand_0 * third_multiplicand_1 * const_0.into()
+ third_addend * const_1.into();
let fourth_computed_output = fourth_multiplicand_0 * fourth_multiplicand_1 * const_0.into()
+ fourth_addend * const_1.into();
constraints.extend((output - computed_output).to_basefield_array());
}
let mut constraints = (first_output - first_computed_output)
.to_basefield_array()
.to_vec();
constraints.extend((second_output - second_computed_output).to_basefield_array());
constraints.extend((third_output - third_computed_output).to_basefield_array());
constraints.extend((fourth_output - fourth_computed_output).to_basefield_array());
constraints
}
@ -159,61 +83,23 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for ArithmeticExtensionGate<D>
let const_0 = vars.local_constants[0];
let const_1 = vars.local_constants[1];
let first_multiplicand_0 = vars.get_local_ext_algebra(Self::wires_first_multiplicand_0());
let first_multiplicand_1 = vars.get_local_ext_algebra(Self::wires_first_multiplicand_1());
let first_addend = vars.get_local_ext_algebra(Self::wires_first_addend());
let second_multiplicand_0 = vars.get_local_ext_algebra(Self::wires_second_multiplicand_0());
let second_multiplicand_1 = vars.get_local_ext_algebra(Self::wires_second_multiplicand_1());
let second_addend = vars.get_local_ext_algebra(Self::wires_second_addend());
let third_multiplicand_0 = vars.get_local_ext_algebra(Self::wires_third_multiplicand_0());
let third_multiplicand_1 = vars.get_local_ext_algebra(Self::wires_third_multiplicand_1());
let third_addend = vars.get_local_ext_algebra(Self::wires_third_addend());
let fourth_multiplicand_0 = vars.get_local_ext_algebra(Self::wires_fourth_multiplicand_0());
let fourth_multiplicand_1 = vars.get_local_ext_algebra(Self::wires_fourth_multiplicand_1());
let fourth_addend = vars.get_local_ext_algebra(Self::wires_fourth_addend());
let first_output = vars.get_local_ext_algebra(Self::wires_first_output());
let second_output = vars.get_local_ext_algebra(Self::wires_second_output());
let third_output = vars.get_local_ext_algebra(Self::wires_third_output());
let fourth_output = vars.get_local_ext_algebra(Self::wires_fourth_output());
let mut constraints = Vec::new();
for i in 0..NUM_ARITHMETIC_OPS {
let multiplicand_0 = vars.get_local_ext_algebra(Self::wires_ith_multiplicand_0(i));
let multiplicand_1 = vars.get_local_ext_algebra(Self::wires_ith_multiplicand_1(i));
let addend = vars.get_local_ext_algebra(Self::wires_ith_addend(i));
let output = vars.get_local_ext_algebra(Self::wires_ith_output(i));
let computed_output = {
let mul = builder.mul_ext_algebra(multiplicand_0, multiplicand_1);
let scaled_mul = builder.scalar_mul_ext_algebra(const_0, mul);
let scaled_addend = builder.scalar_mul_ext_algebra(const_1, addend);
builder.add_ext_algebra(scaled_mul, scaled_addend)
};
let first_computed_output =
builder.mul_ext_algebra(first_multiplicand_0, first_multiplicand_1);
let first_computed_output = builder.scalar_mul_ext_algebra(const_0, first_computed_output);
let first_scaled_addend = builder.scalar_mul_ext_algebra(const_1, first_addend);
let first_computed_output =
builder.add_ext_algebra(first_computed_output, first_scaled_addend);
let diff = builder.sub_ext_algebra(output, computed_output);
constraints.extend(diff.to_ext_target_array());
}
let second_computed_output =
builder.mul_ext_algebra(second_multiplicand_0, second_multiplicand_1);
let second_computed_output =
builder.scalar_mul_ext_algebra(const_0, second_computed_output);
let second_scaled_addend = builder.scalar_mul_ext_algebra(const_1, second_addend);
let second_computed_output =
builder.add_ext_algebra(second_computed_output, second_scaled_addend);
let third_computed_output =
builder.mul_ext_algebra(third_multiplicand_0, third_multiplicand_1);
let third_computed_output = builder.scalar_mul_ext_algebra(const_0, third_computed_output);
let third_scaled_addend = builder.scalar_mul_ext_algebra(const_1, third_addend);
let third_computed_output =
builder.add_ext_algebra(third_computed_output, third_scaled_addend);
let fourth_computed_output =
builder.mul_ext_algebra(fourth_multiplicand_0, fourth_multiplicand_1);
let fourth_computed_output =
builder.scalar_mul_ext_algebra(const_0, fourth_computed_output);
let fourth_scaled_addend = builder.scalar_mul_ext_algebra(const_1, fourth_addend);
let fourth_computed_output =
builder.add_ext_algebra(fourth_computed_output, fourth_scaled_addend);
let diff_0 = builder.sub_ext_algebra(first_output, first_computed_output);
let diff_1 = builder.sub_ext_algebra(second_output, second_computed_output);
let diff_2 = builder.sub_ext_algebra(third_output, third_computed_output);
let diff_3 = builder.sub_ext_algebra(fourth_output, fourth_computed_output);
let mut constraints = diff_0.to_ext_target_array().to_vec();
constraints.extend(diff_1.to_ext_target_array());
constraints.extend(diff_2.to_ext_target_array());
constraints.extend(diff_3.to_ext_target_array());
constraints
}
@ -222,24 +108,21 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for ArithmeticExtensionGate<D>
gate_index: usize,
local_constants: &[F],
) -> Vec<Box<dyn WitnessGenerator<F>>> {
let gens = (0..4)
.map(|i| ArithmeticExtensionGenerator {
gate_index,
const_0: local_constants[0],
const_1: local_constants[1],
i,
(0..4)
.map(|i| {
let g: Box<dyn WitnessGenerator<F>> = Box::new(ArithmeticExtensionGenerator {
gate_index,
const_0: local_constants[0],
const_1: local_constants[1],
i,
});
g
})
.collect::<Vec<_>>();
vec![
Box::new(gens[0].clone()),
Box::new(gens[1].clone()),
Box::new(gens[2].clone()),
Box::new(gens[3].clone()),
]
.collect::<Vec<_>>()
}
fn num_wires(&self) -> usize {
16 * D
NUM_ARITHMETIC_OPS * 4 * D
}
fn num_constants(&self) -> usize {
@ -251,7 +134,7 @@ impl<F: Extendable<D>, const D: usize> Gate<F, D> for ArithmeticExtensionGate<D>
}
fn num_constraints(&self) -> usize {
4 * D
NUM_ARITHMETIC_OPS * D
}
}
@ -265,7 +148,11 @@ struct ArithmeticExtensionGenerator<F: Extendable<D>, const D: usize> {
impl<F: Extendable<D>, const D: usize> SimpleGenerator<F> for ArithmeticExtensionGenerator<F, D> {
fn dependencies(&self) -> Vec<Target> {
(4 * self.i * D..(4 * self.i + 3) * D)
ArithmeticExtensionGate::<D>::wires_ith_multiplicand_0(self.i)
.chain(ArithmeticExtensionGate::<D>::wires_ith_multiplicand_1(
self.i,
))
.chain(ArithmeticExtensionGate::<D>::wires_ith_addend(self.i))
.map(|i| Target::wire(self.gate_index, i))
.collect()
}
@ -276,13 +163,18 @@ impl<F: Extendable<D>, const D: usize> SimpleGenerator<F> for ArithmeticExtensio
witness.get_extension_target(t)
};
let start = 4 * self.i * D;
let multiplicand_0 = extract_extension(start..start + D);
let multiplicand_1 = extract_extension(start + D..start + 2 * D);
let addend = extract_extension(start + 2 * D..start + 3 * D);
let multiplicand_0 = extract_extension(
ArithmeticExtensionGate::<D>::wires_ith_multiplicand_0(self.i),
);
let multiplicand_1 = extract_extension(
ArithmeticExtensionGate::<D>::wires_ith_multiplicand_1(self.i),
);
let addend = extract_extension(ArithmeticExtensionGate::<D>::wires_ith_addend(self.i));
let output_target =
ExtensionTarget::from_range(self.gate_index, start + 3 * D..start + 4 * D);
let output_target = ExtensionTarget::from_range(
self.gate_index,
ArithmeticExtensionGate::<D>::wires_ith_output(self.i),
);
let computed_output =
multiplicand_0 * multiplicand_1 * self.const_0.into() + addend * self.const_1.into();

View File

@ -197,18 +197,8 @@ impl<F: Extendable<D>, const D: usize, const R: usize> Gate<F, D> for GMiMCGate<
let cubing_input_wire = vars.local_wires[Self::wire_cubing_input(r)];
constraints.push(builder.sub_extension(cubing_input, cubing_input_wire));
let f = builder.cube_extension(cubing_input_wire);
// addition_buffer += f
// state[active] -= f
(addition_buffer, state[active]) = builder.double_arithmetic_extension(
F::ONE,
F::ONE,
one,
addition_buffer,
f,
neg_one,
f,
state[active],
);
addition_buffer = builder.add_extension(addition_buffer, f);
state[active] = builder.sub_extension(state[active], f);
}
for i in 0..W {

View File

@ -61,7 +61,9 @@ pub struct CircuitBuilder<F: Extendable<D>, const D: usize> {
constants_to_targets: HashMap<F, Target>,
targets_to_constants: HashMap<Target, F>,
pub(crate) free_arithmetic: Option<(usize, F, F)>,
/// A map `(c0, c1) -> (g, i)` from constants `(c0,c1)` to an available arithmetic gate using
/// these constants with gate index `g` and already using `i` arithmetic operations.
pub(crate) free_arithmetic: HashMap<(F, F), (usize, usize)>,
}
impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
@ -78,7 +80,7 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
generators: Vec::new(),
constants_to_targets: HashMap::new(),
targets_to_constants: HashMap::new(),
free_arithmetic: None,
free_arithmetic: HashMap::new(),
}
}

View File

@ -427,9 +427,9 @@ mod tests {
zero_knowledge: false,
cap_height: 3,
fri_config: FriConfig {
proof_of_work_bits: 20,
proof_of_work_bits: 1,
reduction_arity_bits: vec![3, 3, 3],
num_query_rounds: 27,
num_query_rounds: 40,
},
};
let (proof_with_pis, vd, cd) = {

View File

@ -306,17 +306,8 @@ pub(crate) fn eval_vanishing_poly_recursively<F: Extendable<D>, const D: usize>(
// Holds `k[i] * x`.
let mut s_ids = Vec::new();
for j in 0..common_data.config.num_routed_wires / 2 {
let k_0 = builder.constant(common_data.k_is[2 * j]);
let k_0_ext = builder.convert_to_ext(k_0);
let k_1 = builder.constant(common_data.k_is[2 * j + 1]);
let k_1_ext = builder.convert_to_ext(k_1);
let tmp = builder.mul_two_extension(k_0_ext, x, k_1_ext, x);
s_ids.push(tmp.0);
s_ids.push(tmp.1);
}
if common_data.config.num_routed_wires.is_odd() {
let k = builder.constant(common_data.k_is[common_data.k_is.len() - 1]);
for j in 0..common_data.config.num_routed_wires {
let k = builder.constant(common_data.k_is[j]);
let k_ext = builder.convert_to_ext(k);
s_ids.push(builder.mul_extension(k_ext, x));
}

View File

@ -164,52 +164,15 @@ impl<const D: usize> ReducingFactorTarget<D> {
where
F: Extendable<D>,
{
let zero = builder.zero_extension();
let l = terms.len();
self.count += l as u64;
let mut terms_vec = terms.to_vec();
// If needed, we pad the original vector so that it has even length.
if terms_vec.len().is_odd() {
terms_vec.push(zero);
}
let mut acc = terms_vec.pop().unwrap();
terms_vec.reverse();
let mut acc = zero;
for pair in terms_vec.chunks(2) {
// We will route the output of the first arithmetic operation to the multiplicand of the
// second, i.e. we compute the following:
// out_0 = alpha acc + pair[0]
// acc' = out_1 = alpha out_0 + pair[1]
let (gate, range) = if let Some((g, c_0, c_1)) = builder.free_arithmetic {
if c_0 == F::ONE && c_1 == F::ONE {
(g, ArithmeticExtensionGate::<D>::wires_third_output())
} else {
(
builder.num_gates(),
ArithmeticExtensionGate::<D>::wires_first_output(),
)
}
} else {
(
builder.num_gates(),
ArithmeticExtensionGate::<D>::wires_first_output(),
)
};
let out_0 = ExtensionTarget::from_range(gate, range);
acc = builder
.double_arithmetic_extension(
F::ONE,
F::ONE,
self.base,
acc,
pair[0],
self.base,
out_0,
pair[1],
)
.1;
for x in terms_vec {
acc = builder.mul_add_extension(self.base, acc, x);
}
acc
}
@ -227,21 +190,6 @@ impl<const D: usize> ReducingFactorTarget<D> {
builder.mul_extension(exp, x)
}
pub fn shift_and_mul<F>(
&mut self,
x: ExtensionTarget<D>,
a: ExtensionTarget<D>,
b: ExtensionTarget<D>,
builder: &mut CircuitBuilder<F, D>,
) -> (ExtensionTarget<D>, ExtensionTarget<D>)
where
F: Extendable<D>,
{
let exp = builder.exp_u64_extension(self.base, self.count);
self.count = 0;
builder.mul_two_extension(exp, x, a, b)
}
pub fn reset(&mut self) {
self.count = 0;
}