Working reduction arity

This commit is contained in:
wborgeaud 2021-04-26 18:24:57 +02:00
parent 406df34990
commit 67aa704f6a
3 changed files with 110 additions and 68 deletions

View File

@ -29,7 +29,7 @@ pub(crate) fn interpolant<F: Field>(points: &[(F, F)]) -> PolynomialCoeffs<F> {
/// Interpolate the polynomial defined by an arbitrary set of (point, value) pairs at the given /// Interpolate the polynomial defined by an arbitrary set of (point, value) pairs at the given
/// point `x`. /// point `x`.
fn interpolate<F: Field>(points: &[(F, F)], x: F, barycentric_weights: &[F]) -> F { pub fn interpolate<F: Field>(points: &[(F, F)], x: F, barycentric_weights: &[F]) -> F {
// If x is in the list of points, the Lagrange formula would divide by zero. // If x is in the list of points, the Lagrange formula would divide by zero.
for &(x_i, y_i) in points { for &(x_i, y_i) in points {
if x_i == x { if x_i == x {
@ -37,7 +37,7 @@ fn interpolate<F: Field>(points: &[(F, F)], x: F, barycentric_weights: &[F]) ->
} }
} }
let l_x: F = points.iter().map(|&(x_i, y_i)| x - x_i).product(); let l_x: F = points.iter().map(|&(x_i, _y_i)| x - x_i).product();
let sum = (0..points.len()) let sum = (0..points.len())
.map(|i| { .map(|i| {
@ -51,7 +51,7 @@ fn interpolate<F: Field>(points: &[(F, F)], x: F, barycentric_weights: &[F]) ->
l_x * sum l_x * sum
} }
fn barycentric_weights<F: Field>(points: &[(F, F)]) -> Vec<F> { pub fn barycentric_weights<F: Field>(points: &[(F, F)]) -> Vec<F> {
let n = points.len(); let n = points.len();
(0..n) (0..n)
.map(|i| { .map(|i| {

View File

@ -1,12 +1,13 @@
use crate::field::fft::fft; use crate::field::fft::fft;
use crate::field::field::Field; use crate::field::field::Field;
use crate::field::lagrange::{barycentric_weights, interpolate};
use crate::hash::hash_n_to_1; use crate::hash::hash_n_to_1;
use crate::merkle_proofs::verify_merkle_proof; use crate::merkle_proofs::verify_merkle_proof_subtree;
use crate::merkle_tree::MerkleTree; use crate::merkle_tree::MerkleTree;
use crate::plonk_challenger::Challenger; use crate::plonk_challenger::Challenger;
use crate::polynomial::polynomial::{PolynomialCoeffs, PolynomialValues}; use crate::polynomial::polynomial::{PolynomialCoeffs, PolynomialValues};
use crate::proof::{FriEvaluations, FriMerkleProofs, FriProof, FriQueryRound, Hash}; use crate::proof::{FriEvaluations, FriMerkleProofs, FriProof, FriQueryRound, Hash};
use crate::util::log2_strict; use crate::util::{log2_strict, reverse_bits};
use anyhow::{ensure, Result}; use anyhow::{ensure, Result};
/// Somewhat arbitrary. Smaller values will increase delta, but with diminishing returns, /// Somewhat arbitrary. Smaller values will increase delta, but with diminishing returns,
@ -16,16 +17,14 @@ const EPSILON: f64 = 0.01;
struct FriConfig { struct FriConfig {
proof_of_work_bits: u32, proof_of_work_bits: u32,
rate_bits: usize,
/// The arity of each FRI reduction step, expressed (i.e. the log2 of the actual arity). /// The arity of each FRI reduction step, expressed (i.e. the log2 of the actual arity).
/// For example, `[3, 2, 1]` would describe a FRI reduction tree with 8-to-1 reduction, then /// For example, `[3, 2, 1]` would describe a FRI reduction tree with 8-to-1 reduction, then
/// a 4-to-1 reduction, then a 2-to-1 reduction. After these reductions, the reduced polynomial /// a 4-to-1 reduction, then a 2-to-1 reduction. After these reductions, the reduced polynomial
/// is sent directly. /// is sent directly.
reduction_arity_bits: Vec<usize>, reduction_arity_bits: Vec<usize>,
/// Number of reductions in the FRI protocol. So if the original domain has size `2^n`,
/// then the final domain will have size `2^(n-reduction_count)`.
reduction_count: usize,
/// Number of query rounds to perform. /// Number of query rounds to perform.
num_query_rounds: usize, num_query_rounds: usize,
} }
@ -101,8 +100,9 @@ fn fri_committed_trees<F: Field>(
challenger.observe_hash(&trees[0].root); challenger.observe_hash(&trees[0].root);
for &arity_bits in &config.reduction_arity_bits { let num_reductions = config.reduction_arity_bits.len();
let arity = 1 << arity_bits; for i in 0..num_reductions {
let arity = 1 << config.reduction_arity_bits[i];
let beta = challenger.get_challenge(); let beta = challenger.get_challenge();
// P(x) = sum_{i<r} x^i * P_i(x^r) becomes sum_{i<r} beta^i * P_i(x). // P(x) = sum_{i<r} x^i * P_i(x^r) becomes sum_{i<r} beta^i * P_i(x).
coeffs = PolynomialCoeffs::new( coeffs = PolynomialCoeffs::new(
@ -112,12 +112,16 @@ fn fri_committed_trees<F: Field>(
.map(|chunk| chunk.iter().rev().fold(F::ZERO, |acc, &c| acc * beta + c)) .map(|chunk| chunk.iter().rev().fold(F::ZERO, |acc, &c| acc * beta + c))
.collect::<Vec<_>>(), .collect::<Vec<_>>(),
); );
if i == num_reductions - 1 {
break;
}
values = fft(coeffs.clone()); values = fft(coeffs.clone());
let tree = MerkleTree::new(values.values.iter().map(|&v| vec![v]).collect(), true); let tree = MerkleTree::new(values.values.iter().map(|&v| vec![v]).collect(), true);
challenger.observe_hash(&tree.root); challenger.observe_hash(&tree.root);
trees.push(tree); trees.push(tree);
} }
challenger.observe_elements(&coeffs.coeffs);
(trees, coeffs) (trees, coeffs)
} }
@ -212,7 +216,7 @@ fn fri_query_round<F: Field>(
if i == 0 { if i == 0 {
// For the first layer, we need to send the evaluation at `x` too. // For the first layer, we need to send the evaluation at `x` too.
evals.evals.push( evals.evals.push(
roots_coset_indices[1..] roots_coset_indices
.iter() .iter()
.map(|&index| tree.get(index)[0]) .map(|&index| tree.get(index)[0])
.collect(), .collect(),
@ -221,16 +225,12 @@ fn fri_query_round<F: Field>(
// For the other layers, we don't need to send the evaluation at `x`, since it can // For the other layers, we don't need to send the evaluation at `x`, since it can
// be inferred by the verifier. See the `compute_evaluation` function. // be inferred by the verifier. See the `compute_evaluation` function.
evals.evals.push( evals.evals.push(
roots_coset_indices roots_coset_indices[1..]
.iter() .iter()
.map(|&index| tree.get(index)[0]) .map(|&index| tree.get(index)[0])
.collect(), .collect(),
); );
} }
dbg!(roots_coset_indices
.into_iter()
.map(|i| i & ((1 << log2_strict(next_domain_size)) - 1))
.collect::<Vec<_>>());
merkle_proofs.proofs.push(tree.prove_subtree( merkle_proofs.proofs.push(tree.prove_subtree(
x_index & ((1 << log2_strict(next_domain_size)) - 1), x_index & ((1 << log2_strict(next_domain_size)) - 1),
arity_bits, arity_bits,
@ -244,34 +244,46 @@ fn fri_query_round<F: Field>(
}); });
} }
/// Computes P'(x^2) from {P(x*g^i)}_(i=0..arity), where g is a `arity`-th root of unity and P' is the FRI reduced polynomial. /// Computes P'(x^arity) from {P(x*g^i)}_(i=0..arity), where g is a `arity`-th root of unity and P' is the FRI reduced polynomial.
fn compute_evaluation<F: Field>(x: F, arity_bits: usize, last_evals: Vec<F>, beta: F) -> F { fn compute_evaluation<F: Field>(x: F, arity_bits: usize, last_evals: &[F], beta: F) -> F {
let g = F::primitive_root_of_unity(arity_bits); let g = F::primitive_root_of_unity(arity_bits);
let points = g let points = g
.powers() .powers()
.zip(last_evals)
.take(1 << arity_bits) .take(1 << arity_bits)
.map(|y| x * y) .map(|(y, &e)| (x * y, e))
.collect::<Vec<_>>(); .collect::<Vec<_>>();
(last_e_x + last_e_x_minus) / F::TWO + beta * (last_e_x - last_e_x_minus) / (F::TWO * x) let barycentric_weights = barycentric_weights(&points);
interpolate(&points, beta, &barycentric_weights)
} }
fn verify_fri_proof<F: Field>( fn verify_fri_proof<F: Field>(
purported_degree_log: usize,
proof: &FriProof<F>, proof: &FriProof<F>,
challenger: &mut Challenger<F>, challenger: &mut Challenger<F>,
config: &FriConfig, config: &FriConfig,
) -> Result<()> { ) -> Result<()> {
let total_arities = config.reduction_arity_bits.iter().sum::<usize>();
ensure!(
purported_degree_log
== log2_strict(proof.final_poly.len()) + total_arities - config.rate_bits,
"Final polynomial has wrong degree."
);
// Size of the LDE domain. // Size of the LDE domain.
let n = proof.final_poly.len() << config.reduction_count; let n = proof.final_poly.len() << total_arities;
// Recover the random betas used in the FRI reductions. // Recover the random betas used in the FRI reductions.
let betas = proof.commit_phase_merkle_roots[..proof.commit_phase_merkle_roots.len() - 1] // let betas = proof.commit_phase_merkle_roots[..proof.commit_phase_merkle_roots.len() - 1]
let betas = proof
.commit_phase_merkle_roots
.iter() .iter()
.map(|root| { .map(|root| {
challenger.observe_hash(root); challenger.observe_hash(root);
challenger.get_challenge() challenger.get_challenge()
}) })
.collect::<Vec<_>>(); .collect::<Vec<_>>();
challenger.observe_hash(proof.commit_phase_merkle_roots.last().unwrap()); // challenger.observe_hash(proof.commit_phase_merkle_roots.last().unwrap());
challenger.observe_elements(&proof.final_poly.coeffs);
// Check PoW. // Check PoW.
fri_verify_proof_of_work(proof, challenger, config)?; fri_verify_proof_of_work(proof, challenger, config)?;
@ -281,7 +293,7 @@ fn verify_fri_proof<F: Field>(
"Number of query rounds does not match config." "Number of query rounds does not match config."
); );
ensure!( ensure!(
config.reduction_count > 0, !config.reduction_arity_bits.is_empty(),
"Number of reductions should be non-zero." "Number of reductions should be non-zero."
); );
@ -293,53 +305,66 @@ fn verify_fri_proof<F: Field>(
let mut x_index = x.to_canonical_u64() as usize; let mut x_index = x.to_canonical_u64() as usize;
// `subgroup_x` is `subgroup[x_index]`, i.e., the actual field element in the domain. // `subgroup_x` is `subgroup[x_index]`, i.e., the actual field element in the domain.
let mut subgroup_x = F::primitive_root_of_unity(log2_strict(n)).exp_usize(x_index % n); let mut subgroup_x = F::primitive_root_of_unity(log2_strict(n)).exp_usize(x_index % n);
for i in 0..config.reduction_count { for (i, &arity_bits) in config.reduction_arity_bits.iter().enumerate() {
let arity = 1 << arity_bits;
x_index %= domain_size; x_index %= domain_size;
let next_domain_size = domain_size >> 1; let next_domain_size = domain_size >> arity_bits;
let minus_x_index = (next_domain_size + x_index) % domain_size; let roots_coset_indices =
let (e_x, e_x_minus, merkle_proof, merkle_proof_minus) = if i == 0 { index_roots_coset(x_index, next_domain_size, domain_size, arity);
let (e_x, e_x_minus) = round_proof.evals.first_layer; if i == 0 {
let (merkle_proof, merkle_proof_minus) = &round_proof.merkle_proofs.proofs[i]; let evals = round_proof.evals.evals[0].clone();
e_xs.push((e_x, e_x_minus)); e_xs.push(evals);
(e_x, e_x_minus, merkle_proof, merkle_proof_minus)
} else { } else {
let (last_e_x, last_e_x_minus) = e_xs[i - 1]; let last_evals = &e_xs[i - 1];
let e_x = compute_evaluation(subgroup_x, last_e_x, last_e_x_minus, betas[i - 1]); let e_x = compute_evaluation(
let e_x_minus = round_proof.evals.rest[i - 1]; subgroup_x,
let (merkle_proof, merkle_proof_minus) = &round_proof.merkle_proofs.proofs[i]; config.reduction_arity_bits[i - 1],
e_xs.push((e_x, e_x_minus)); last_evals,
(e_x, e_x_minus, merkle_proof, merkle_proof_minus) betas[i - 1],
);
let mut evals = round_proof.evals.evals[i].clone();
evals.insert(0, e_x);
e_xs.push(evals);
}; };
verify_merkle_proof( let sorted_evals = {
vec![e_x], let mut sorted_evals_enumerate = e_xs[i].iter().enumerate().collect::<Vec<_>>();
x_index, sorted_evals_enumerate.sort_by_key(|&(j, _)| {
reverse_bits(roots_coset_indices[j], log2_strict(domain_size))
});
sorted_evals_enumerate
.into_iter()
.map(|(_, &e)| vec![e])
.collect()
};
verify_merkle_proof_subtree(
sorted_evals,
x_index & ((1 << log2_strict(next_domain_size)) - 1),
proof.commit_phase_merkle_roots[i], proof.commit_phase_merkle_roots[i],
merkle_proof, &round_proof.merkle_proofs.proofs[i],
true,
)?;
verify_merkle_proof(
vec![e_x_minus],
minus_x_index,
proof.commit_phase_merkle_roots[i],
merkle_proof_minus,
true, true,
)?; )?;
if i > 0 { if i > 0 {
subgroup_x = subgroup_x.square(); for _ in 0..config.reduction_arity_bits[i - 1] {
subgroup_x = subgroup_x.square();
}
} }
domain_size = next_domain_size; domain_size = next_domain_size;
} }
let (last_e_x, last_e_x_minus) = e_xs[config.reduction_count - 1]; let last_evals = e_xs.last().unwrap();
let final_arity_bits = *config.reduction_arity_bits.last().unwrap();
let purported_eval = compute_evaluation( let purported_eval = compute_evaluation(
subgroup_x, subgroup_x,
last_e_x, final_arity_bits,
last_e_x_minus, last_evals,
betas[config.reduction_count - 1], *betas.last().unwrap(),
); );
for _ in 0..final_arity_bits {
subgroup_x = subgroup_x.square();
}
// Final check of FRI. After all the reductions, we check that the final polynomial is equal // Final check of FRI. After all the reductions, we check that the final polynomial is equal
// to the one sent by the prover. // to the one sent by the prover.
ensure!( ensure!(
proof.final_poly.eval(subgroup_x.square()) == purported_eval, proof.final_poly.eval(subgroup_x) == purported_eval,
"Final polynomial evaluation is invalid." "Final polynomial evaluation is invalid."
); );
} }
@ -353,41 +378,57 @@ mod tests {
use crate::field::crandall_field::CrandallField; use crate::field::crandall_field::CrandallField;
use crate::field::fft::ifft; use crate::field::fft::ifft;
use anyhow::Result; use anyhow::Result;
use rand::Rng;
fn test_fri( fn test_fri(
degree: usize, degree_log: usize,
rate_bits: usize, rate_bits: usize,
reduction_count: usize, reduction_arity_bits: Vec<usize>,
num_query_rounds: usize, num_query_rounds: usize,
) -> Result<()> { ) -> Result<()> {
type F = CrandallField; type F = CrandallField;
let n = degree; let n = 1 << degree_log;
let evals = PolynomialValues::new((0..n).map(|_| F::rand()).collect()); let evals = PolynomialValues::new((0..n).map(|_| F::rand()).collect());
let lde = evals.clone().lde(rate_bits); let lde = evals.clone().lde(rate_bits);
let config = FriConfig { let config = FriConfig {
reduction_count,
num_query_rounds, num_query_rounds,
rate_bits,
proof_of_work_bits: 2, proof_of_work_bits: 2,
reduction_arity_bits: Vec::new(), reduction_arity_bits,
}; };
let mut challenger = Challenger::new(); let mut challenger = Challenger::new();
let proof = fri_proof(&ifft(lde.clone()), &lde, &mut challenger, &config); let proof = fri_proof(&ifft(lde.clone()), &lde, &mut challenger, &config);
let mut challenger = Challenger::new(); let mut challenger = Challenger::new();
verify_fri_proof(&proof, &mut challenger, &config)?; verify_fri_proof(degree_log, &proof, &mut challenger, &config)?;
Ok(()) Ok(())
} }
fn gen_arities(degree_log: usize) -> Vec<usize> {
let mut rng = rand::thread_rng();
let mut arities = Vec::new();
let mut remaining = degree_log;
while remaining > 0 {
let arity = rng.gen_range(0, remaining + 1);
arities.push(arity);
remaining -= arity;
}
arities
}
#[test] #[test]
fn test_fri_multi_params() -> Result<()> { fn test_fri_multi_params() -> Result<()> {
for degree_log in 1..6 { for degree_log in 1..6 {
for rate_bits in 0..4 { for rate_bits in 0..4 {
for reduction_count in 1..=(degree_log + rate_bits) { for num_query_round in 0..4 {
for num_query_round in 0..4 { test_fri(
test_fri(1 << degree_log, rate_bits, reduction_count, num_query_round)?; degree_log,
} rate_bits,
gen_arities(degree_log),
num_query_round,
)?;
} }
} }
} }

View File

@ -1,8 +1,9 @@
use crate::circuit_builder::CircuitBuilder; use crate::circuit_builder::CircuitBuilder;
use crate::field::field::Field; use crate::field::field::Field;
use crate::gates::gmimc::GMiMCGate; use crate::gates::gmimc::GMiMCGate;
use crate::hash::GMIMC_ROUNDS;
use crate::hash::{compress, hash_or_noop}; use crate::hash::{compress, hash_or_noop};
use crate::hash::{merkle_root_inner, GMIMC_ROUNDS}; use crate::merkle_tree::MerkleTree;
use crate::proof::{Hash, HashTarget}; use crate::proof::{Hash, HashTarget};
use crate::target::Target; use crate::target::Target;
use crate::wire::Wire; use crate::wire::Wire;
@ -61,7 +62,7 @@ pub(crate) fn verify_merkle_proof_subtree<F: Field>(
} else { } else {
subtree_index subtree_index
}; };
let mut current_digest = merkle_root_inner(subtree_leaves_data); let mut current_digest = MerkleTree::new(subtree_leaves_data, false).root;
for (i, &sibling_digest) in proof.siblings.iter().enumerate() { for (i, &sibling_digest) in proof.siblings.iter().enumerate() {
let bit = (index >> i & 1) == 1; let bit = (index >> i & 1) == 1;
current_digest = if bit { current_digest = if bit {