Merge pull request #175 from mir-protocol/some_more_arithm_opt

Some more arithmetic optimizations
This commit is contained in:
wborgeaud 2021-08-14 11:48:28 +02:00 committed by GitHub
commit 47e9f5461e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 70 additions and 23 deletions

View File

@ -139,7 +139,13 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
let precomputed_reduced_evals = with_context!(
self,
"precompute reduced evaluations",
PrecomputedReducedEvalsTarget::from_os_and_alpha(os, alpha, self)
PrecomputedReducedEvalsTarget::from_os_and_alpha(
os,
alpha,
common_data.degree_bits,
zeta,
self
)
);
for (i, round_proof) in proof.query_round_proofs.iter().enumerate() {
@ -204,8 +210,8 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
&mut self,
proof: &FriInitialTreeProofTarget,
alpha: ExtensionTarget<D>,
zeta: ExtensionTarget<D>,
subgroup_x: Target,
vanish_zeta: ExtensionTarget<D>,
precomputed_reduced_evals: PrecomputedReducedEvalsTarget<D>,
common_data: &CommonCircuitData<F, D>,
) -> ExtensionTarget<D> {
@ -218,7 +224,6 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
- config.rate_bits
);
let subgroup_x = self.convert_to_ext(subgroup_x);
let vanish_zeta = self.sub_extension(subgroup_x, zeta);
let mut alpha = ReducingFactorTarget::new(alpha);
let mut sum = self.zero_extension();
@ -255,19 +260,19 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
.collect::<Vec<_>>();
let zs_composition_eval = alpha.reduce_base(&zs_evals, self);
let g = self.constant_extension(F::Extension::primitive_root_of_unity(degree_log));
let zeta_right = self.mul_extension(g, zeta);
let interpol_val = self.interpolate2(
[
(zeta, precomputed_reduced_evals.zs),
(zeta_right, precomputed_reduced_evals.zs_right),
],
subgroup_x,
let interpol_val = self.mul_add_extension(
vanish_zeta,
precomputed_reduced_evals.slope,
precomputed_reduced_evals.zs,
);
let (zs_numerator, vanish_zeta_right) =
self.sub_two_extension(zs_composition_eval, interpol_val, subgroup_x, zeta_right);
let zs_denominator = self.mul_extension(vanish_zeta, vanish_zeta_right);
sum = alpha.shift(sum, self);
let (zs_numerator, vanish_zeta_right) = self.sub_two_extension(
zs_composition_eval,
interpol_val,
subgroup_x,
precomputed_reduced_evals.zeta_right,
);
let (mut sum, zs_denominator) =
alpha.shift_and_mul(sum, vanish_zeta, vanish_zeta_right, self);
sum = self.div_add_extension(zs_numerator, zs_denominator, sum);
sum
@ -307,12 +312,26 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
);
// `subgroup_x` is `subgroup[x_index]`, i.e., the actual field element in the domain.
let mut subgroup_x = with_context!(self, "compute x from its index", {
let g = self.constant(F::MULTIPLICATIVE_GROUP_GENERATOR);
let (mut subgroup_x, vanish_zeta) = with_context!(self, "compute x from its index", {
let g = self.constant(F::coset_shift());
let phi = self.constant(F::primitive_root_of_unity(n_log));
let phi = self.exp_from_bits(phi, x_index_bits.iter().rev());
self.mul(g, phi)
let g_ext = self.convert_to_ext(g);
let phi_ext = self.convert_to_ext(phi);
let zero = self.zero_extension();
// `subgroup_x = g*phi, vanish_zeta = g*phi - zeta`
let tmp = self.double_arithmetic_extension(
F::ONE,
F::NEG_ONE,
g_ext,
phi_ext,
zero,
g_ext,
phi_ext,
zeta,
);
(tmp.0 .0[0], tmp.1)
});
// old_eval is the last derived evaluation; it will be checked for consistency with its
@ -323,8 +342,8 @@ impl<F: Extendable<D>, const D: usize> CircuitBuilder<F, D> {
self.fri_combine_initial(
&round_proof.initial_trees_proof,
alpha,
zeta,
subgroup_x,
vanish_zeta,
precomputed_reduced_evals,
common_data,
)
@ -393,12 +412,17 @@ struct PrecomputedReducedEvalsTarget<const D: usize> {
pub single: ExtensionTarget<D>,
pub zs: ExtensionTarget<D>,
pub zs_right: ExtensionTarget<D>,
/// Slope of the line from `(zeta, zs)` to `(zeta_right, zs_right)`.
pub slope: ExtensionTarget<D>,
pub zeta_right: ExtensionTarget<D>,
}
impl<const D: usize> PrecomputedReducedEvalsTarget<D> {
fn from_os_and_alpha<F: Extendable<D>>(
os: &OpeningSetTarget<D>,
alpha: ExtensionTarget<D>,
degree_log: usize,
zeta: ExtensionTarget<D>,
builder: &mut CircuitBuilder<F, D>,
) -> Self {
let mut alpha = ReducingFactorTarget::new(alpha);
@ -416,10 +440,16 @@ impl<const D: usize> PrecomputedReducedEvalsTarget<D> {
let zs = alpha.reduce(&os.plonk_zs, builder);
let zs_right = alpha.reduce(&os.plonk_zs_right, builder);
let g = builder.constant_extension(F::Extension::primitive_root_of_unity(degree_log));
let zeta_right = builder.mul_extension(g, zeta);
let (numerator, denominator) = builder.sub_two_extension(zs_right, zs, zeta_right, zeta);
Self {
single,
zs,
zs_right,
slope: builder.div_extension(numerator, denominator),
zeta_right,
}
}
}

View File

@ -363,8 +363,9 @@ pub(crate) fn eval_vanishing_poly_recursively<F: Extendable<D>, const D: usize>(
.chunks(max_degree)
.zip(partial_product_check.iter_mut())
.for_each(|(d, q)| {
let tmp = builder.mul_many_extension(d);
*q = builder.mul_extension(*q, tmp);
let mut v = d.to_vec();
v.push(*q);
*q = builder.mul_many_extension(&v);
});
vanishing_partial_products_terms.extend(partial_product_check);

View File

@ -211,9 +211,25 @@ impl<const D: usize> ReducingFactorTarget<D> {
F: Extendable<D>,
{
let exp = builder.exp_u64_extension(self.base, self.count);
let tmp = builder.mul_extension(exp, x);
self.count = 0;
tmp
builder.mul_extension(exp, x)
}
/// Returns `(self.shift(x), a*b)`.
/// Used to take advantage of the second arithmetic operation in the `ArithmeticExtensionGate`.
pub fn shift_and_mul<F>(
&mut self,
x: ExtensionTarget<D>,
a: ExtensionTarget<D>,
b: ExtensionTarget<D>,
builder: &mut CircuitBuilder<F, D>,
) -> (ExtensionTarget<D>, ExtensionTarget<D>)
where
F: Extendable<D>,
{
let exp = builder.exp_u64_extension(self.base, self.count);
self.count = 0;
builder.mul_two_extension(exp, x, a, b)
}
pub fn reset(&mut self) {