Merge pull request #491 from mir-protocol/fix_reduction_strategy

Fix reduction strategy
This commit is contained in:
wborgeaud 2022-02-18 17:07:03 +01:00 committed by GitHub
commit 9516e14c3e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 24 additions and 14 deletions

View File

@ -35,6 +35,7 @@ impl FriConfig {
let reduction_arity_bits = self.reduction_strategy.reduction_arity_bits(
degree_bits,
self.rate_bits,
self.cap_height,
self.num_query_rounds,
);
FriParams {
@ -67,7 +68,7 @@ pub struct FriParams {
}
impl FriParams {
pub(crate) fn total_arities(&self) -> usize {
pub fn total_arities(&self) -> usize {
self.reduction_arity_bits.iter().sum()
}

View File

@ -9,9 +9,10 @@ pub enum FriReductionStrategy {
Fixed(Vec<usize>),
/// `ConstantArityBits(arity_bits, final_poly_bits)` applies reductions of arity `2^arity_bits`
/// until the polynomial degree is `2^final_poly_bits` or less. This tends to work well in the
/// recursive setting, as it avoids needing multiple configurations of gates used in FRI
/// verification, such as `InterpolationGate`.
/// until the polynomial degree is less than or equal to `2^final_poly_bits` or until any further
/// `arity_bits`-reduction makes the last FRI tree have height less than `cap_height`.
/// This tends to work well in the recursive setting, as it avoids needing multiple configurations
/// of gates used in FRI verification, such as `InterpolationGate`.
ConstantArityBits(usize, usize),
/// `MinSize(opt_max_arity_bits)` searches for an optimal sequence of reduction arities, with an
@ -26,17 +27,20 @@ impl FriReductionStrategy {
&self,
mut degree_bits: usize,
rate_bits: usize,
cap_height: usize,
num_queries: usize,
) -> Vec<usize> {
match self {
FriReductionStrategy::Fixed(reduction_arity_bits) => reduction_arity_bits.to_vec(),
FriReductionStrategy::ConstantArityBits(arity_bits, final_poly_bits) => {
&FriReductionStrategy::ConstantArityBits(arity_bits, final_poly_bits) => {
let mut result = Vec::new();
while degree_bits > *final_poly_bits {
result.push(*arity_bits);
assert!(degree_bits >= *arity_bits);
degree_bits -= *arity_bits;
while degree_bits > final_poly_bits
&& degree_bits + rate_bits - arity_bits >= cap_height
{
result.push(arity_bits);
assert!(degree_bits >= arity_bits);
degree_bits -= arity_bits;
}
result.shrink_to_fit();
result

View File

@ -639,6 +639,7 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
let mut timing = TimingTree::new("preprocess", Level::Trace);
let start = Instant::now();
let rate_bits = self.config.fri_config.rate_bits;
let cap_height = self.config.fri_config.cap_height;
// Hash the public inputs, and route them to a `PublicInputGate` which will enforce that
// those hash wires match the claimed public inputs.
@ -664,7 +665,7 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
let degree_bits = log2_strict(degree);
let fri_params = self.fri_params(degree_bits);
assert!(
fri_params.total_arities() <= degree_bits,
fri_params.total_arities() <= degree_bits + rate_bits - cap_height,
"FRI total reduction arity is too large.",
);
@ -705,7 +706,7 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
constants_sigmas_vecs,
rate_bits,
PlonkOracle::CONSTANTS_SIGMAS.blinding,
self.config.fri_config.cap_height,
cap_height,
&mut timing,
Some(&fft_root_table),
);

View File

@ -40,6 +40,13 @@ where
{
let degree = trace.len();
let degree_bits = log2_strict(degree);
let fri_params = config.fri_params(degree_bits);
let rate_bits = config.fri_config.rate_bits;
let cap_height = config.fri_config.cap_height;
assert!(
fri_params.total_arities() <= degree_bits + rate_bits - cap_height,
"FRI total reduction arity is too large.",
);
let trace_vecs = trace.iter().map(|row| row.to_vec()).collect_vec();
let trace_col_major: Vec<Vec<F>> = transpose(&trace_vecs);
@ -53,8 +60,6 @@ where
.collect()
);
let rate_bits = config.fri_config.rate_bits;
let cap_height = config.fri_config.cap_height;
let trace_commitment = timed!(
timing,
"compute trace commitment",
@ -160,7 +165,6 @@ where
.chain(permutation_zs_commitment.as_ref())
.chain(once(&quotient_commitment))
.collect_vec();
let fri_params = config.fri_params(degree_bits);
let opening_proof = timed!(
timing,