Implement Frobenius optimization discussed in #61 comments to avoid calling the Frobenius for every wires.

This commit is contained in:
wborgeaud 2021-06-11 10:27:03 +02:00
parent 4106a47ded
commit 1ebeab2c3a
4 changed files with 37 additions and 5 deletions

View File

@ -28,6 +28,14 @@ pub trait OEF<const D: usize>: FieldExtension<D> {
Self::from_basefield_array(res)
}
/// Repeated Frobenius automorphisms: x -> x^(p^k).
// TODO: Implement this. Is basically the same as `frobenius` above, but using
// `z = W^floor(j*p^k/D)`. I'm not sure there is a closed form for these so
// might require to hardcode them.
fn repeated_frobenius(&self, k: usize) -> Self {
todo!()
}
}
impl<F: Field> OEF<1> for F {

View File

@ -31,6 +31,15 @@ impl<const D: usize> ExtensionTarget<D> {
res.try_into().unwrap()
}
// TODO: Implement this. See comment in `OEF::repeated_frobenius`.
fn repeated_frobenius<F: Extendable<D>>(
&self,
k: usize,
builder: &mut CircuitBuilder<F, D>,
) -> Self {
todo!()
}
pub fn from_range(gate: usize, range: Range<usize>) -> Self {
debug_assert_eq!(range.end - range.start, D);
Target::wires_from_range(gate, range).try_into().unwrap()

View File

@ -7,6 +7,7 @@ use std::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssi
use num::Integer;
use rand::Rng;
use crate::field::extension_field::OEF;
use crate::util::bits_u64;
/// A finite field with prime order less than 2^64.
@ -283,3 +284,18 @@ impl<F: Field> Iterator for Powers<F> {
Some(result)
}
}
impl<F: Field> Powers<F> {
/// Apply the Frobenius automorphism `k` times.
// TODO: Use `OEF::repeated_frobenius` when it is implemented.
pub fn repeated_frobenius<const D: usize>(self, k: usize) -> Self
where
F: OEF<D>,
{
let Self { base, current } = self;
Self {
base: (0..k).fold(base, |acc, _| acc.frobenius()),
current: (0..k).fold(current, |acc, _| acc.frobenius()),
}
}
}

View File

@ -199,11 +199,10 @@ fn fri_combine_initial<F: Field + Extendable<D>, const D: usize>(
.map(|(&e, a)| a * e.into())
.sum();
let zeta_frob = zeta.frobenius();
let wire_evals_frob = os.wires.iter().map(|e| e.frobenius()).collect::<Vec<_>>();
let wires_interpol = interpolant(&[
(zeta, reduce_with_iter(&os.wires, alpha_powers.clone())),
(zeta_frob, reduce_with_iter(&wire_evals_frob, alpha_powers)),
]);
let ev_zeta = reduce_with_iter(&os.wires, alpha_powers.clone());
let mut alpha_powers_frob = alpha_powers.repeated_frobenius(D - 1);
let ev_zeta_frob = reduce_with_iter(&os.wires, alpha_powers_frob).frobenius();
let wires_interpol = interpolant(&[(zeta, ev_zeta), (zeta_frob, ev_zeta_frob)]);
let numerator = ev - wires_interpol.eval(subgroup_x);
let denominator = (subgroup_x - zeta) * (subgroup_x - zeta_frob);
sum += numerator / denominator;