Merge pull request #173 from mir-protocol/minor_arithmetic_optim

Minor arithmetic optimizations
This commit is contained in:
wborgeaud 2021-08-12 18:27:46 +02:00 committed by GitHub
commit 5bce9ca90d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 45 additions and 14 deletions

View File

@ -152,6 +152,18 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
None
}
/// Returns `a*b + c*d + e`.
pub fn wide_arithmetic_extension(
&mut self,
a: ExtensionTarget<D>,
b: ExtensionTarget<D>,
c: ExtensionTarget<D>,
d: ExtensionTarget<D>,
e: ExtensionTarget<D>,
) -> ExtensionTarget<D> {
self.inner_product_extension(F::ONE, e, vec![(a, b), (c, d)])
}
/// Returns `sum_{(a,b) in vecs} constant * a * b`.
pub fn inner_product_extension(
&mut self,
@ -230,11 +242,7 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
c: ExtensionTarget<D>,
) -> ExtensionTarget<D> {
let one = self.one_extension();
let gate = self.num_gates();
let first_out =
ExtensionTarget::from_range(gate, ArithmeticExtensionGate::<D>::wires_first_output());
self.double_arithmetic_extension(F::ONE, F::ONE, one, a, b, one, c, first_out)
.1
self.wide_arithmetic_extension(one, a, one, b, c)
}
/// Add `n` `ExtensionTarget`s with `n/2` `ArithmeticExtensionGate`s.

View File

@ -1,3 +1,5 @@
use num::Integer;
use crate::field::extension_field::target::ExtensionTarget;
use crate::field::extension_field::Extendable;
use crate::field::field_types::Field;
@ -238,7 +240,7 @@ pub fn evaluate_gate_constraints_recursively<F: Extendable<D>, const D: usize>(
num_gate_constraints: usize,
vars: EvaluationTargets<D>,
) -> Vec<ExtensionTarget<D>> {
let mut constraints = vec![builder.zero_extension(); num_gate_constraints];
let mut all_gate_constraints = vec![vec![]; num_gate_constraints];
for gate in gates {
let gate_constraints = with_context!(
builder,
@ -248,9 +250,13 @@ pub fn evaluate_gate_constraints_recursively<F: Extendable<D>, const D: usize>(
.eval_filtered_recursively(builder, vars, &gate.prefix)
);
for (i, c) in gate_constraints.into_iter().enumerate() {
constraints[i] = builder.add_extension(constraints[i], c);
all_gate_constraints[i].push(c);
}
}
let mut constraints = vec![builder.zero_extension(); num_gate_constraints];
for (i, v) in all_gate_constraints.into_iter().enumerate() {
constraints[i] = builder.add_many_extension(&v);
}
constraints
}
@ -274,6 +280,7 @@ pub(crate) fn eval_vanishing_poly_recursively<F: Extendable<D>, const D: usize>(
gammas: &[Target],
alphas: &[Target],
) -> Vec<ExtensionTarget<D>> {
let one = builder.one_extension();
let max_degree = common_data.quotient_degree_factor;
let (num_prods, final_num_prod) = common_data.num_partial_products;
@ -297,6 +304,23 @@ pub(crate) fn eval_vanishing_poly_recursively<F: Extendable<D>, const D: usize>(
let l1_x = eval_l_1_recursively(builder, common_data.degree(), x, x_pow_deg);
// Holds `k[i] * x`.
let mut s_ids = Vec::new();
for j in 0..common_data.config.num_routed_wires / 2 {
let k_0 = builder.constant(common_data.k_is[2 * j]);
let k_0_ext = builder.convert_to_ext(k_0);
let k_1 = builder.constant(common_data.k_is[2 * j + 1]);
let k_1_ext = builder.convert_to_ext(k_1);
let tmp = builder.mul_two_extension(k_0_ext, x, k_1_ext, x);
s_ids.push(tmp.0);
s_ids.push(tmp.1);
}
if common_data.config.num_routed_wires.is_odd() {
let k = builder.constant(common_data.k_is[common_data.k_is.len() - 1]);
let k_ext = builder.convert_to_ext(k);
s_ids.push(builder.mul_extension(k_ext, x));
}
for i in 0..common_data.config.num_challenges {
let z_x = local_zs[i];
let z_gz = next_zs[i];
@ -305,20 +329,19 @@ pub(crate) fn eval_vanishing_poly_recursively<F: Extendable<D>, const D: usize>(
let numerator_values = (0..common_data.config.num_routed_wires)
.map(|j| {
let wire_value = vars.local_wires[j];
let k_i = builder.constant(common_data.k_is[j]);
let s_id = builder.scalar_mul_ext(k_i, x);
let beta_ext = builder.convert_to_ext(betas[i]);
let gamma_ext = builder.convert_to_ext(gammas[i]);
let tmp = builder.scalar_mul_add_extension(betas[i], s_id, wire_value);
builder.add_extension(tmp, gamma_ext)
// `beta * s_id + wire_value + gamma`
builder.wide_arithmetic_extension(beta_ext, s_ids[j], one, wire_value, gamma_ext)
})
.collect::<Vec<_>>();
let denominator_values = (0..common_data.config.num_routed_wires)
.map(|j| {
let wire_value = vars.local_wires[j];
let s_sigma = s_sigmas[j];
let beta_ext = builder.convert_to_ext(betas[i]);
let gamma_ext = builder.convert_to_ext(gammas[i]);
let tmp = builder.scalar_mul_add_extension(betas[i], s_sigma, wire_value);
builder.add_extension(tmp, gamma_ext)
// `beta * s_sigma + wire_value + gamma`
builder.wide_arithmetic_extension(beta_ext, s_sigmas[j], one, wire_value, gamma_ext)
})
.collect::<Vec<_>>();
let quotient_values = (0..common_data.config.num_routed_wires)