Merge pull request #255 from mir-protocol/better_compressed_merkle_paths

Simpler Merkle paths compression
This commit is contained in:
wborgeaud 2021-09-21 08:56:34 +02:00 committed by GitHub
commit 5d8241760f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 463 additions and 19 deletions

View File

@ -1,18 +1,21 @@
use itertools::izip;
use serde::{Deserialize, Serialize};
use crate::field::extension_field::target::ExtensionTarget;
use crate::field::extension_field::Extendable;
use crate::field::field_types::Field;
use crate::field::extension_field::{flatten, unflatten, Extendable};
use crate::field::field_types::{Field, RichField};
use crate::gadgets::polynomial::PolynomialCoeffsExtTarget;
use crate::hash::hash_types::MerkleCapTarget;
use crate::hash::merkle_proofs::{MerkleProof, MerkleProofTarget};
use crate::hash::merkle_tree::MerkleCap;
use crate::hash::path_compression::{compress_merkle_proofs, decompress_merkle_proofs};
use crate::iop::target::Target;
use crate::plonk::circuit_data::CommonCircuitData;
use crate::plonk::plonk_common::PolynomialsIndexBlinding;
use crate::polynomial::polynomial::PolynomialCoeffs;
/// Evaluations and Merkle proof produced by the prover in a FRI query step.
#[derive(Serialize, Deserialize, Clone, Debug)]
#[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq)]
#[serde(bound = "")]
pub struct FriQueryStep<F: Extendable<D>, const D: usize> {
pub evals: Vec<F::Extension>,
@ -27,7 +30,7 @@ pub struct FriQueryStepTarget<const D: usize> {
/// Evaluations and Merkle proofs of the original set of polynomials,
/// before they are combined into a composition polynomial.
#[derive(Serialize, Deserialize, Clone, Debug)]
#[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq)]
#[serde(bound = "")]
pub struct FriInitialTreeProof<F: Field> {
pub evals_proofs: Vec<(Vec<F>, MerkleProof<F>)>,
@ -61,9 +64,10 @@ impl FriInitialTreeProofTarget {
}
/// Proof for a FRI query round.
#[derive(Serialize, Deserialize, Clone, Debug)]
#[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq)]
#[serde(bound = "")]
pub struct FriQueryRound<F: Extendable<D>, const D: usize> {
pub index: usize,
pub initial_trees_proof: FriInitialTreeProof<F>,
pub steps: Vec<FriQueryStep<F, D>>,
}
@ -74,7 +78,7 @@ pub struct FriQueryRoundTarget<const D: usize> {
pub steps: Vec<FriQueryStepTarget<D>>,
}
#[derive(Serialize, Deserialize, Clone, Debug)]
#[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq)]
#[serde(bound = "")]
pub struct FriProof<F: Extendable<D>, const D: usize> {
/// A Merkle cap for each reduced polynomial in the commit phase.
@ -85,6 +89,8 @@ pub struct FriProof<F: Extendable<D>, const D: usize> {
pub final_poly: PolynomialCoeffs<F::Extension>,
/// Witness showing that the prover did PoW.
pub pow_witness: F,
/// Flag set to true if path compression has been applied to the proof's Merkle proofs.
pub is_compressed: bool,
}
pub struct FriProofTarget<const D: usize> {
@ -93,3 +99,187 @@ pub struct FriProofTarget<const D: usize> {
pub final_poly: PolynomialCoeffsExtTarget<D>,
pub pow_witness: Target,
}
impl<F: RichField + Extendable<D>, const D: usize> FriProof<F, D> {
/// Compress all the Merkle paths in the FRI proof.
pub fn compress(self, common_data: &CommonCircuitData<F, D>) -> Self {
if self.is_compressed {
panic!("Proof is already compressed.");
}
let FriProof {
commit_phase_merkle_caps,
mut query_round_proofs,
final_poly,
pow_witness,
..
} = self;
let cap_height = common_data.config.cap_height;
let reduction_arity_bits = &common_data.config.fri_config.reduction_arity_bits;
let num_reductions = reduction_arity_bits.len();
let num_initial_trees = query_round_proofs[0].initial_trees_proof.evals_proofs.len();
// "Transpose" the query round proofs, so that information for each Merkle tree is collected together.
let mut initial_trees_indices = vec![vec![]; num_initial_trees];
let mut initial_trees_leaves = vec![vec![]; num_initial_trees];
let mut initial_trees_proofs = vec![vec![]; num_initial_trees];
let mut steps_indices = vec![vec![]; num_reductions];
let mut steps_evals = vec![vec![]; num_reductions];
let mut steps_proofs = vec![vec![]; num_reductions];
for qrp in &query_round_proofs {
let FriQueryRound {
mut index,
initial_trees_proof,
steps,
} = qrp.clone();
for (i, (leaves_data, proof)) in
initial_trees_proof.evals_proofs.into_iter().enumerate()
{
initial_trees_indices[i].push(index);
initial_trees_leaves[i].push(leaves_data);
initial_trees_proofs[i].push(proof);
}
for (i, query_step) in steps.into_iter().enumerate() {
index >>= reduction_arity_bits[i];
steps_indices[i].push(index);
steps_evals[i].push(query_step.evals);
steps_proofs[i].push(query_step.merkle_proof);
}
}
// Compress all Merkle proofs.
let initial_trees_proofs = initial_trees_indices
.iter()
.zip(initial_trees_proofs)
.map(|(is, ps)| compress_merkle_proofs(cap_height, is, &ps))
.collect::<Vec<_>>();
let steps_proofs = steps_indices
.iter()
.zip(steps_proofs)
.map(|(is, ps)| compress_merkle_proofs(cap_height, is, &ps))
.collect::<Vec<_>>();
// Replace the query round proofs with the compressed versions.
for (i, qrp) in query_round_proofs.iter_mut().enumerate() {
qrp.initial_trees_proof = FriInitialTreeProof {
evals_proofs: (0..num_initial_trees)
.map(|j| {
(
initial_trees_leaves[j][i].clone(),
initial_trees_proofs[j][i].clone(),
)
})
.collect(),
};
qrp.steps = (0..num_reductions)
.map(|j| FriQueryStep {
evals: steps_evals[j][i].clone(),
merkle_proof: steps_proofs[j][i].clone(),
})
.collect();
}
FriProof {
commit_phase_merkle_caps,
query_round_proofs,
final_poly,
pow_witness,
is_compressed: true,
}
}
/// Decompress all the Merkle paths in the FRI proof.
pub fn decompress(self, common_data: &CommonCircuitData<F, D>) -> Self {
if !self.is_compressed {
panic!("Proof is not compressed.");
}
let FriProof {
commit_phase_merkle_caps,
mut query_round_proofs,
final_poly,
pow_witness,
..
} = self;
let cap_height = common_data.config.cap_height;
let reduction_arity_bits = &common_data.config.fri_config.reduction_arity_bits;
let num_reductions = reduction_arity_bits.len();
let num_initial_trees = query_round_proofs[0].initial_trees_proof.evals_proofs.len();
// "Transpose" the query round proofs, so that information for each Merkle tree is collected together.
let mut initial_trees_indices = vec![vec![]; num_initial_trees];
let mut initial_trees_leaves = vec![vec![]; num_initial_trees];
let mut initial_trees_proofs = vec![vec![]; num_initial_trees];
let mut steps_indices = vec![vec![]; num_reductions];
let mut steps_evals = vec![vec![]; num_reductions];
let mut steps_proofs = vec![vec![]; num_reductions];
let height = common_data.degree_bits + common_data.config.rate_bits;
let heights = reduction_arity_bits
.iter()
.scan(height, |acc, &bits| {
*acc -= bits;
Some(*acc)
})
.collect::<Vec<_>>();
for qrp in &query_round_proofs {
let FriQueryRound {
mut index,
initial_trees_proof,
steps,
} = qrp.clone();
for (i, (leaves_data, proof)) in
initial_trees_proof.evals_proofs.into_iter().enumerate()
{
initial_trees_indices[i].push(index);
initial_trees_leaves[i].push(leaves_data);
initial_trees_proofs[i].push(proof);
}
for (i, query_step) in steps.into_iter().enumerate() {
index >>= reduction_arity_bits[i];
steps_indices[i].push(index);
steps_evals[i].push(flatten(&query_step.evals));
steps_proofs[i].push(query_step.merkle_proof);
}
}
// Decompress all Merkle proofs.
let initial_trees_proofs = izip!(
&initial_trees_leaves,
&initial_trees_indices,
initial_trees_proofs
)
.map(|(ls, is, ps)| decompress_merkle_proofs(&ls, is, &ps, height, cap_height))
.collect::<Vec<_>>();
let steps_proofs = izip!(&steps_evals, &steps_indices, steps_proofs, heights)
.map(|(ls, is, ps, h)| decompress_merkle_proofs(ls, is, &ps, h, cap_height))
.collect::<Vec<_>>();
// Replace the query round proofs with the decompressed versions.
for (i, qrp) in query_round_proofs.iter_mut().enumerate() {
qrp.initial_trees_proof = FriInitialTreeProof {
evals_proofs: (0..num_initial_trees)
.map(|j| {
(
initial_trees_leaves[j][i].clone(),
initial_trees_proofs[j][i].clone(),
)
})
.collect(),
};
qrp.steps = (0..num_reductions)
.map(|j| FriQueryStep {
evals: unflatten(&steps_evals[j][i]),
merkle_proof: steps_proofs[j][i].clone(),
})
.collect();
}
FriProof {
commit_phase_merkle_caps,
query_round_proofs,
final_poly,
pow_witness,
is_compressed: false,
}
}
}

View File

@ -63,6 +63,7 @@ pub fn fri_proof<F: RichField + Extendable<D>, const D: usize>(
query_round_proofs,
final_poly: final_coeffs,
pow_witness,
is_compressed: false,
}
}
@ -152,7 +153,8 @@ fn fri_prover_query_round<F: RichField + Extendable<D>, const D: usize>(
) -> FriQueryRound<F, D> {
let mut query_steps = Vec::new();
let x = challenger.get_challenge();
let mut x_index = x.to_canonical_u64() as usize % n;
let initial_index = x.to_canonical_u64() as usize % n;
let mut x_index = initial_index;
let initial_proof = initial_merkle_trees
.iter()
.map(|t| (t.get(x_index).to_vec(), t.prove(x_index)))
@ -170,6 +172,7 @@ fn fri_prover_query_round<F: RichField + Extendable<D>, const D: usize>(
x_index >>= arity_bits;
}
FriQueryRound {
index: initial_index,
initial_trees_proof: FriInitialTreeProof {
evals_proofs: initial_proof,
},

View File

@ -259,6 +259,7 @@ fn fri_verifier_query_round<F: RichField + Extendable<D>, const D: usize>(
let config = &common_data.config.fri_config;
let x = challenger.get_challenge();
let mut x_index = x.to_canonical_u64() as usize % n;
ensure!(x_index == round_proof.index, "Wrong index.");
fri_verify_initial_proof(
x_index,
&round_proof.initial_trees_proof,

View File

@ -13,7 +13,7 @@ use crate::iop::target::{BoolTarget, Target};
use crate::iop::wire::Wire;
use crate::plonk::circuit_builder::CircuitBuilder;
#[derive(Clone, Debug, Serialize, Deserialize)]
#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
#[serde(bound = "")]
pub struct MerkleProof<F: Field> {
/// The Merkle digest of each sibling subtree, staying from the bottommost layer.

View File

@ -8,11 +8,15 @@ use crate::hash::merkle_proofs::MerkleProof;
/// The Merkle cap of height `h` of a Merkle tree is the `h`-th layer (from the root) of the tree.
/// It can be used in place of the root to verify Merkle paths, which are `h` elements shorter.
#[derive(Clone, Debug, Serialize, Deserialize)]
#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
#[serde(bound = "")]
pub struct MerkleCap<F: Field>(pub Vec<HashOut<F>>);
impl<F: Field> MerkleCap<F> {
pub fn len(&self) -> usize {
self.0.len()
}
pub fn flatten(&self) -> Vec<F> {
self.0.iter().flat_map(|h| h.elements).collect()
}

View File

@ -3,6 +3,7 @@ pub mod hash_types;
pub mod hashing;
pub mod merkle_proofs;
pub mod merkle_tree;
pub mod path_compression;
pub mod poseidon;
pub mod rescue;

View File

@ -0,0 +1,149 @@
use std::collections::HashMap;
use anyhow::{ensure, Result};
use num::Integer;
use serde::{Deserialize, Serialize};
use crate::field::field_types::{Field, RichField};
use crate::hash::hash_types::HashOut;
use crate::hash::hashing::{compress, hash_or_noop};
use crate::hash::merkle_proofs::MerkleProof;
use crate::hash::merkle_tree::MerkleCap;
use crate::util::log2_strict;
/// Compress multiple Merkle proofs on the same tree by removing redundancy in the Merkle paths.
pub(crate) fn compress_merkle_proofs<F: Field>(
cap_height: usize,
indices: &[usize],
proofs: &[MerkleProof<F>],
) -> Vec<MerkleProof<F>> {
assert!(!proofs.is_empty());
let height = cap_height + proofs[0].siblings.len();
let num_leaves = 1 << height;
let mut compressed_proofs = Vec::with_capacity(proofs.len());
// Holds the known nodes in the tree at a given time. The root is at index 1.
// Valid indices are 1 through n, and each element at index `i` has
// children at indices `2i` and `2i +1` its parent at index `floor(i 2)`.
let mut known = vec![false; 2 * num_leaves];
for &i in indices {
// The leaves are known.
known[i + num_leaves] = true;
}
// For each proof collect all the unknown proof elements.
for (&i, p) in indices.iter().zip(proofs) {
let mut compressed_proof = MerkleProof {
siblings: Vec::new(),
};
let mut index = i + num_leaves;
for &sibling in &p.siblings {
let sibling_index = index ^ 1;
if !known[sibling_index] {
// If the sibling is not yet known, add it to the proof and set it to known.
compressed_proof.siblings.push(sibling);
known[sibling_index] = true;
}
// Go up the tree and set the parent to known.
index >>= 1;
known[index] = true;
}
compressed_proofs.push(compressed_proof);
}
compressed_proofs
}
/// Decompress compressed Merkle proofs.
/// Note: The data and indices must be in the same order as in `compress_merkle_proofs`.
pub(crate) fn decompress_merkle_proofs<F: RichField>(
leaves_data: &[Vec<F>],
leaves_indices: &[usize],
compressed_proofs: &[MerkleProof<F>],
height: usize,
cap_height: usize,
) -> Vec<MerkleProof<F>> {
let num_leaves = 1 << height;
let mut compressed_proofs = compressed_proofs.to_vec();
let mut decompressed_proofs = Vec::with_capacity(compressed_proofs.len());
// Holds the already seen nodes in the tree along with their value.
let mut seen = HashMap::new();
for (&i, v) in leaves_indices.iter().zip(leaves_data) {
// Observe the leaves.
seen.insert(i + num_leaves, hash_or_noop(v.to_vec()));
}
// For every index, go up the tree by querying `seen` to get node values, or if they are unknown
// get them from the compressed proof.
for (&i, p) in leaves_indices.iter().zip(compressed_proofs) {
let mut compressed_siblings = p.siblings.into_iter();
let mut decompressed_proof = MerkleProof {
siblings: Vec::new(),
};
let mut index = i + num_leaves;
let mut current_digest = seen[&index];
for _ in 0..height - cap_height {
let sibling_index = index ^ 1;
// Get the value of the sibling node by querying it or getting it from the proof.
let h = *seen
.entry(sibling_index)
.or_insert_with(|| compressed_siblings.next().unwrap());
decompressed_proof.siblings.push(h);
// Update the current digest to the value of the parent.
current_digest = if index.is_even() {
compress(current_digest, h)
} else {
compress(h, current_digest)
};
// Observe the parent.
index >>= 1;
seen.insert(index, current_digest);
}
decompressed_proofs.push(decompressed_proof);
}
decompressed_proofs
}
#[cfg(test)]
mod tests {
use rand::{thread_rng, Rng};
use super::*;
use crate::field::crandall_field::CrandallField;
use crate::field::field_types::Field;
use crate::hash::merkle_proofs::MerkleProof;
use crate::hash::merkle_tree::MerkleTree;
#[test]
fn test_path_compression() {
type F = CrandallField;
let h = 10;
let cap_height = 3;
let vs = (0..1 << h).map(|_| vec![F::rand()]).collect::<Vec<_>>();
let mt = MerkleTree::new(vs.clone(), cap_height);
let mut rng = thread_rng();
let k = rng.gen_range(1..=1 << h);
let indices = (0..k).map(|_| rng.gen_range(0..1 << h)).collect::<Vec<_>>();
let proofs = indices.iter().map(|&i| mt.prove(i)).collect::<Vec<_>>();
let compressed_proofs = compress_merkle_proofs(cap_height, &indices, &proofs);
let decompressed_proofs = decompress_merkle_proofs(
&indices.iter().map(|&i| vs[i].clone()).collect::<Vec<_>>(),
&indices,
&compressed_proofs,
h,
cap_height,
);
assert_eq!(proofs, decompressed_proofs);
let compressed_proof_bytes = serde_cbor::to_vec(&compressed_proofs).unwrap();
println!(
"Compressed proof length: {} bytes",
compressed_proof_bytes.len()
);
let proof_bytes = serde_cbor::to_vec(&proofs).unwrap();
println!("Proof length: {} bytes", proof_bytes.len());
}
}

View File

@ -11,7 +11,7 @@ use crate::hash::merkle_tree::MerkleCap;
use crate::iop::target::Target;
use crate::plonk::circuit_data::CommonCircuitData;
#[derive(Serialize, Deserialize, Clone, Debug)]
#[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq)]
#[serde(bound = "")]
pub struct Proof<F: Extendable<D>, const D: usize> {
/// Merkle cap of LDEs of wire values.
@ -26,13 +26,6 @@ pub struct Proof<F: Extendable<D>, const D: usize> {
pub opening_proof: FriProof<F, D>,
}
#[derive(Serialize, Deserialize, Clone, Debug)]
#[serde(bound = "")]
pub struct ProofWithPublicInputs<F: Extendable<D>, const D: usize> {
pub proof: Proof<F, D>,
pub public_inputs: Vec<F>,
}
pub struct ProofTarget<const D: usize> {
pub wires_cap: MerkleCapTarget,
pub plonk_zs_partial_products_cap: MerkleCapTarget,
@ -41,12 +34,57 @@ pub struct ProofTarget<const D: usize> {
pub opening_proof: FriProofTarget<D>,
}
impl<F: RichField + Extendable<D>, const D: usize> Proof<F, D> {
/// Returns `true` iff the opening proof is compressed.
pub fn is_compressed(&self) -> bool {
self.opening_proof.is_compressed
}
/// Compress the opening proof.
pub fn compress(mut self, common_data: &CommonCircuitData<F, D>) -> Self {
self.opening_proof = self.opening_proof.compress(common_data);
self
}
/// Decompress the opening proof.
pub fn decompress(mut self, common_data: &CommonCircuitData<F, D>) -> Self {
self.opening_proof = self.opening_proof.decompress(common_data);
self
}
}
#[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq)]
#[serde(bound = "")]
pub struct ProofWithPublicInputs<F: Extendable<D>, const D: usize> {
pub proof: Proof<F, D>,
pub public_inputs: Vec<F>,
}
pub struct ProofWithPublicInputsTarget<const D: usize> {
pub proof: ProofTarget<D>,
pub public_inputs: Vec<Target>,
}
#[derive(Clone, Debug, Serialize, Deserialize)]
impl<F: RichField + Extendable<D>, const D: usize> ProofWithPublicInputs<F, D> {
/// Returns `true` iff the opening proof is compressed.
pub fn is_compressed(&self) -> bool {
self.proof.is_compressed()
}
/// Compress the opening proof.
pub fn compress(mut self, common_data: &CommonCircuitData<F, D>) -> Self {
self.proof = self.proof.compress(common_data);
self
}
/// Decompress the opening proof.
pub fn decompress(mut self, common_data: &CommonCircuitData<F, D>) -> Self {
self.proof = self.proof.decompress(common_data);
self
}
}
#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
/// The purported values of each polynomial at a single point.
pub struct OpeningSet<F: Extendable<D>, const D: usize> {
pub constants: Vec<F::Extension>,
@ -102,3 +140,49 @@ pub struct OpeningSetTarget<const D: usize> {
pub partial_products: Vec<ExtensionTarget<D>>,
pub quotient_polys: Vec<ExtensionTarget<D>>,
}
#[cfg(test)]
mod tests {
use anyhow::Result;
use crate::field::crandall_field::CrandallField;
use crate::field::extension_field::algebra::ExtensionAlgebra;
use crate::field::extension_field::quartic::QuarticExtension;
use crate::field::field_types::Field;
use crate::iop::witness::PartialWitness;
use crate::plonk::circuit_builder::CircuitBuilder;
use crate::plonk::circuit_data::CircuitConfig;
use crate::plonk::verifier::verify;
#[test]
fn test_proof_compression() -> Result<()> {
type F = CrandallField;
type FF = QuarticExtension<CrandallField>;
const D: usize = 4;
let config = CircuitConfig::large_config();
let pw = PartialWitness::new();
let mut builder = CircuitBuilder::<F, D>::new(config);
// Build dummy circuit to get a valid proof.
let x = F::rand();
let y = F::rand();
let z = x * y;
let xt = builder.constant(x);
let yt = builder.constant(y);
let zt = builder.constant(z);
let comp_zt = builder.mul(xt, yt);
builder.connect(zt, comp_zt);
let data = builder.build();
let proof = data.prove(pw)?;
// Verify that `decompress ∘ compress = identity`.
let compressed_proof = proof.clone().compress(&data.common);
let decompressed_compressed_proof = compressed_proof.clone().decompress(&data.common);
assert_eq!(proof, decompressed_compressed_proof);
verify(proof, &data.verifier_only, &data.common)?;
verify(compressed_proof, &data.verifier_only, &data.common)
}
}

View File

@ -479,8 +479,16 @@ mod tests {
builder.print_gate_counts(0);
let data = builder.build();
let recursive_proof = data.prove(pw)?;
let now = std::time::Instant::now();
let compressed_recursive_proof = recursive_proof.clone().compress(&data.common);
info!("{:.4} to compress proof", now.elapsed().as_secs_f64());
let proof_bytes = serde_cbor::to_vec(&recursive_proof).unwrap();
info!("Proof length: {} bytes", proof_bytes.len());
let compressed_proof_bytes = serde_cbor::to_vec(&compressed_recursive_proof).unwrap();
info!(
"Compressed proof length: {} bytes",
compressed_proof_bytes.len()
);
verify(recursive_proof, &data.verifier_only, &data.common)
}
}

View File

@ -12,10 +12,14 @@ use crate::plonk::vanishing_poly::eval_vanishing_poly;
use crate::plonk::vars::EvaluationVars;
pub(crate) fn verify<F: RichField + Extendable<D>, const D: usize>(
proof_with_pis: ProofWithPublicInputs<F, D>,
mut proof_with_pis: ProofWithPublicInputs<F, D>,
verifier_data: &VerifierOnlyCircuitData<F>,
common_data: &CommonCircuitData<F, D>,
) -> Result<()> {
// Decompress the proof if needed.
if proof_with_pis.is_compressed() {
proof_with_pis = proof_with_pis.decompress(common_data);
}
let ProofWithPublicInputs {
proof,
public_inputs,