diff --git a/src/fri/proof.rs b/src/fri/proof.rs index 4ee5bafb..eef39ad6 100644 --- a/src/fri/proof.rs +++ b/src/fri/proof.rs @@ -1,18 +1,21 @@ +use itertools::izip; use serde::{Deserialize, Serialize}; use crate::field::extension_field::target::ExtensionTarget; -use crate::field::extension_field::Extendable; -use crate::field::field_types::Field; +use crate::field::extension_field::{flatten, unflatten, Extendable}; +use crate::field::field_types::{Field, RichField}; use crate::gadgets::polynomial::PolynomialCoeffsExtTarget; use crate::hash::hash_types::MerkleCapTarget; use crate::hash::merkle_proofs::{MerkleProof, MerkleProofTarget}; use crate::hash::merkle_tree::MerkleCap; +use crate::hash::path_compression::{compress_merkle_proofs, decompress_merkle_proofs}; use crate::iop::target::Target; +use crate::plonk::circuit_data::CommonCircuitData; use crate::plonk::plonk_common::PolynomialsIndexBlinding; use crate::polynomial::polynomial::PolynomialCoeffs; /// Evaluations and Merkle proof produced by the prover in a FRI query step. -#[derive(Serialize, Deserialize, Clone, Debug)] +#[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq)] #[serde(bound = "")] pub struct FriQueryStep, const D: usize> { pub evals: Vec, @@ -27,7 +30,7 @@ pub struct FriQueryStepTarget { /// Evaluations and Merkle proofs of the original set of polynomials, /// before they are combined into a composition polynomial. -#[derive(Serialize, Deserialize, Clone, Debug)] +#[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq)] #[serde(bound = "")] pub struct FriInitialTreeProof { pub evals_proofs: Vec<(Vec, MerkleProof)>, @@ -61,9 +64,10 @@ impl FriInitialTreeProofTarget { } /// Proof for a FRI query round. 
-#[derive(Serialize, Deserialize, Clone, Debug)] +#[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq)] #[serde(bound = "")] pub struct FriQueryRound, const D: usize> { + pub index: usize, pub initial_trees_proof: FriInitialTreeProof, pub steps: Vec>, } @@ -74,7 +78,7 @@ pub struct FriQueryRoundTarget { pub steps: Vec>, } -#[derive(Serialize, Deserialize, Clone, Debug)] +#[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq)] #[serde(bound = "")] pub struct FriProof, const D: usize> { /// A Merkle cap for each reduced polynomial in the commit phase. @@ -85,6 +89,8 @@ pub struct FriProof, const D: usize> { pub final_poly: PolynomialCoeffs, /// Witness showing that the prover did PoW. pub pow_witness: F, + /// Flag set to true if path compression has been applied to the proof's Merkle proofs. + pub is_compressed: bool, } pub struct FriProofTarget { @@ -93,3 +99,187 @@ pub struct FriProofTarget { pub final_poly: PolynomialCoeffsExtTarget, pub pow_witness: Target, } + +impl, const D: usize> FriProof { + /// Compress all the Merkle paths in the FRI proof. + pub fn compress(self, common_data: &CommonCircuitData) -> Self { + if self.is_compressed { + panic!("Proof is already compressed."); + } + let FriProof { + commit_phase_merkle_caps, + mut query_round_proofs, + final_poly, + pow_witness, + .. + } = self; + let cap_height = common_data.config.cap_height; + let reduction_arity_bits = &common_data.config.fri_config.reduction_arity_bits; + let num_reductions = reduction_arity_bits.len(); + let num_initial_trees = query_round_proofs[0].initial_trees_proof.evals_proofs.len(); + + // "Transpose" the query round proofs, so that information for each Merkle tree is collected together. 
+ let mut initial_trees_indices = vec![vec![]; num_initial_trees]; + let mut initial_trees_leaves = vec![vec![]; num_initial_trees]; + let mut initial_trees_proofs = vec![vec![]; num_initial_trees]; + let mut steps_indices = vec![vec![]; num_reductions]; + let mut steps_evals = vec![vec![]; num_reductions]; + let mut steps_proofs = vec![vec![]; num_reductions]; + + for qrp in &query_round_proofs { + let FriQueryRound { + mut index, + initial_trees_proof, + steps, + } = qrp.clone(); + for (i, (leaves_data, proof)) in + initial_trees_proof.evals_proofs.into_iter().enumerate() + { + initial_trees_indices[i].push(index); + initial_trees_leaves[i].push(leaves_data); + initial_trees_proofs[i].push(proof); + } + for (i, query_step) in steps.into_iter().enumerate() { + index >>= reduction_arity_bits[i]; + steps_indices[i].push(index); + steps_evals[i].push(query_step.evals); + steps_proofs[i].push(query_step.merkle_proof); + } + } + + // Compress all Merkle proofs. + let initial_trees_proofs = initial_trees_indices + .iter() + .zip(initial_trees_proofs) + .map(|(is, ps)| compress_merkle_proofs(cap_height, is, &ps)) + .collect::>(); + let steps_proofs = steps_indices + .iter() + .zip(steps_proofs) + .map(|(is, ps)| compress_merkle_proofs(cap_height, is, &ps)) + .collect::>(); + + // Replace the query round proofs with the compressed versions. + for (i, qrp) in query_round_proofs.iter_mut().enumerate() { + qrp.initial_trees_proof = FriInitialTreeProof { + evals_proofs: (0..num_initial_trees) + .map(|j| { + ( + initial_trees_leaves[j][i].clone(), + initial_trees_proofs[j][i].clone(), + ) + }) + .collect(), + }; + qrp.steps = (0..num_reductions) + .map(|j| FriQueryStep { + evals: steps_evals[j][i].clone(), + merkle_proof: steps_proofs[j][i].clone(), + }) + .collect(); + } + + FriProof { + commit_phase_merkle_caps, + query_round_proofs, + final_poly, + pow_witness, + is_compressed: true, + } + } + + /// Decompress all the Merkle paths in the FRI proof. 
+ pub fn decompress(self, common_data: &CommonCircuitData) -> Self { + if !self.is_compressed { + panic!("Proof is not compressed."); + } + let FriProof { + commit_phase_merkle_caps, + mut query_round_proofs, + final_poly, + pow_witness, + .. + } = self; + let cap_height = common_data.config.cap_height; + let reduction_arity_bits = &common_data.config.fri_config.reduction_arity_bits; + let num_reductions = reduction_arity_bits.len(); + let num_initial_trees = query_round_proofs[0].initial_trees_proof.evals_proofs.len(); + + // "Transpose" the query round proofs, so that information for each Merkle tree is collected together. + let mut initial_trees_indices = vec![vec![]; num_initial_trees]; + let mut initial_trees_leaves = vec![vec![]; num_initial_trees]; + let mut initial_trees_proofs = vec![vec![]; num_initial_trees]; + let mut steps_indices = vec![vec![]; num_reductions]; + let mut steps_evals = vec![vec![]; num_reductions]; + let mut steps_proofs = vec![vec![]; num_reductions]; + let height = common_data.degree_bits + common_data.config.rate_bits; + let heights = reduction_arity_bits + .iter() + .scan(height, |acc, &bits| { + *acc -= bits; + Some(*acc) + }) + .collect::>(); + + for qrp in &query_round_proofs { + let FriQueryRound { + mut index, + initial_trees_proof, + steps, + } = qrp.clone(); + for (i, (leaves_data, proof)) in + initial_trees_proof.evals_proofs.into_iter().enumerate() + { + initial_trees_indices[i].push(index); + initial_trees_leaves[i].push(leaves_data); + initial_trees_proofs[i].push(proof); + } + for (i, query_step) in steps.into_iter().enumerate() { + index >>= reduction_arity_bits[i]; + steps_indices[i].push(index); + steps_evals[i].push(flatten(&query_step.evals)); + steps_proofs[i].push(query_step.merkle_proof); + } + } + + // Decompress all Merkle proofs. 
+ let initial_trees_proofs = izip!( + &initial_trees_leaves, + &initial_trees_indices, + initial_trees_proofs + ) + .map(|(ls, is, ps)| decompress_merkle_proofs(&ls, is, &ps, height, cap_height)) + .collect::>(); + let steps_proofs = izip!(&steps_evals, &steps_indices, steps_proofs, heights) + .map(|(ls, is, ps, h)| decompress_merkle_proofs(ls, is, &ps, h, cap_height)) + .collect::>(); + + // Replace the query round proofs with the decompressed versions. + for (i, qrp) in query_round_proofs.iter_mut().enumerate() { + qrp.initial_trees_proof = FriInitialTreeProof { + evals_proofs: (0..num_initial_trees) + .map(|j| { + ( + initial_trees_leaves[j][i].clone(), + initial_trees_proofs[j][i].clone(), + ) + }) + .collect(), + }; + qrp.steps = (0..num_reductions) + .map(|j| FriQueryStep { + evals: unflatten(&steps_evals[j][i]), + merkle_proof: steps_proofs[j][i].clone(), + }) + .collect(); + } + + FriProof { + commit_phase_merkle_caps, + query_round_proofs, + final_poly, + pow_witness, + is_compressed: false, + } + } +} diff --git a/src/fri/prover.rs b/src/fri/prover.rs index a0d71d98..6bc0562a 100644 --- a/src/fri/prover.rs +++ b/src/fri/prover.rs @@ -63,6 +63,7 @@ pub fn fri_proof, const D: usize>( query_round_proofs, final_poly: final_coeffs, pow_witness, + is_compressed: false, } } @@ -152,7 +153,8 @@ fn fri_prover_query_round, const D: usize>( ) -> FriQueryRound { let mut query_steps = Vec::new(); let x = challenger.get_challenge(); - let mut x_index = x.to_canonical_u64() as usize % n; + let initial_index = x.to_canonical_u64() as usize % n; + let mut x_index = initial_index; let initial_proof = initial_merkle_trees .iter() .map(|t| (t.get(x_index).to_vec(), t.prove(x_index))) @@ -170,6 +172,7 @@ fn fri_prover_query_round, const D: usize>( x_index >>= arity_bits; } FriQueryRound { + index: initial_index, initial_trees_proof: FriInitialTreeProof { evals_proofs: initial_proof, }, diff --git a/src/fri/verifier.rs b/src/fri/verifier.rs index e14e586f..20579da7 100644 --- 
a/src/fri/verifier.rs +++ b/src/fri/verifier.rs @@ -259,6 +259,7 @@ fn fri_verifier_query_round, const D: usize>( let config = &common_data.config.fri_config; let x = challenger.get_challenge(); let mut x_index = x.to_canonical_u64() as usize % n; + ensure!(x_index == round_proof.index, "Wrong index."); fri_verify_initial_proof( x_index, &round_proof.initial_trees_proof, diff --git a/src/hash/merkle_proofs.rs b/src/hash/merkle_proofs.rs index 7a176dd9..4b7fdf63 100644 --- a/src/hash/merkle_proofs.rs +++ b/src/hash/merkle_proofs.rs @@ -13,7 +13,7 @@ use crate::iop::target::{BoolTarget, Target}; use crate::iop::wire::Wire; use crate::plonk::circuit_builder::CircuitBuilder; -#[derive(Clone, Debug, Serialize, Deserialize)] +#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)] #[serde(bound = "")] pub struct MerkleProof { /// The Merkle digest of each sibling subtree, starting from the bottommost layer. diff --git a/src/hash/merkle_tree.rs b/src/hash/merkle_tree.rs index 1bc28d5a..9b9ffe7e 100644 --- a/src/hash/merkle_tree.rs +++ b/src/hash/merkle_tree.rs @@ -8,11 +8,15 @@ use crate::hash::merkle_proofs::MerkleProof; /// The Merkle cap of height `h` of a Merkle tree is the `h`-th layer (from the root) of the tree. /// It can be used in place of the root to verify Merkle paths, which are `h` elements shorter.
-#[derive(Clone, Debug, Serialize, Deserialize)] +#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)] #[serde(bound = "")] pub struct MerkleCap(pub Vec>); impl MerkleCap { + pub fn len(&self) -> usize { + self.0.len() + } + pub fn flatten(&self) -> Vec { + self.0.iter().flat_map(|h| h.elements).collect() + } diff --git a/src/hash/mod.rs b/src/hash/mod.rs index 7ba7c42c..e33023ec 100644 --- a/src/hash/mod.rs +++ b/src/hash/mod.rs @@ -3,6 +3,7 @@ pub mod hash_types; pub mod hashing; pub mod merkle_proofs; pub mod merkle_tree; +pub mod path_compression; pub mod poseidon; pub mod rescue; diff --git a/src/hash/path_compression.rs b/src/hash/path_compression.rs new file mode 100644 index 00000000..b492e69b --- /dev/null +++ b/src/hash/path_compression.rs @@ -0,0 +1,149 @@ +use std::collections::HashMap; + +use anyhow::{ensure, Result}; +use num::Integer; +use serde::{Deserialize, Serialize}; + +use crate::field::field_types::{Field, RichField}; +use crate::hash::hash_types::HashOut; +use crate::hash::hashing::{compress, hash_or_noop}; +use crate::hash::merkle_proofs::MerkleProof; +use crate::hash::merkle_tree::MerkleCap; +use crate::util::log2_strict; + +/// Compress multiple Merkle proofs on the same tree by removing redundancy in the Merkle paths. +pub(crate) fn compress_merkle_proofs( + cap_height: usize, + indices: &[usize], + proofs: &[MerkleProof], +) -> Vec> { + assert!(!proofs.is_empty()); + let height = cap_height + proofs[0].siblings.len(); + let num_leaves = 1 << height; + let mut compressed_proofs = Vec::with_capacity(proofs.len()); + // Holds the known nodes in the tree at a given time. The root is at index 1. + // Valid indices are 1 through n, and each element at index `i` has + // children at indices `2i` and `2i + 1`, and its parent at index `floor(i / 2)`. + let mut known = vec![false; 2 * num_leaves]; + for &i in indices { + // The leaves are known. + known[i + num_leaves] = true; + } + // For each proof, collect all the unknown proof elements.
+ for (&i, p) in indices.iter().zip(proofs) { + let mut compressed_proof = MerkleProof { + siblings: Vec::new(), + }; + let mut index = i + num_leaves; + for &sibling in &p.siblings { + let sibling_index = index ^ 1; + if !known[sibling_index] { + // If the sibling is not yet known, add it to the proof and set it to known. + compressed_proof.siblings.push(sibling); + known[sibling_index] = true; + } + // Go up the tree and set the parent to known. + index >>= 1; + known[index] = true; + } + compressed_proofs.push(compressed_proof); + } + + compressed_proofs +} + +/// Decompress compressed Merkle proofs. +/// Note: The data and indices must be in the same order as in `compress_merkle_proofs`. +pub(crate) fn decompress_merkle_proofs( + leaves_data: &[Vec], + leaves_indices: &[usize], + compressed_proofs: &[MerkleProof], + height: usize, + cap_height: usize, +) -> Vec> { + let num_leaves = 1 << height; + let mut compressed_proofs = compressed_proofs.to_vec(); + let mut decompressed_proofs = Vec::with_capacity(compressed_proofs.len()); + // Holds the already seen nodes in the tree along with their value. + let mut seen = HashMap::new(); + + for (&i, v) in leaves_indices.iter().zip(leaves_data) { + // Observe the leaves. + seen.insert(i + num_leaves, hash_or_noop(v.to_vec())); + } + // For every index, go up the tree by querying `seen` to get node values, or if they are unknown + // get them from the compressed proof. + for (&i, p) in leaves_indices.iter().zip(compressed_proofs) { + let mut compressed_siblings = p.siblings.into_iter(); + let mut decompressed_proof = MerkleProof { + siblings: Vec::new(), + }; + let mut index = i + num_leaves; + let mut current_digest = seen[&index]; + for _ in 0..height - cap_height { + let sibling_index = index ^ 1; + // Get the value of the sibling node by querying it or getting it from the proof. 
+ let h = *seen + .entry(sibling_index) + .or_insert_with(|| compressed_siblings.next().unwrap()); + decompressed_proof.siblings.push(h); + // Update the current digest to the value of the parent. + current_digest = if index.is_even() { + compress(current_digest, h) + } else { + compress(h, current_digest) + }; + // Observe the parent. + index >>= 1; + seen.insert(index, current_digest); + } + + decompressed_proofs.push(decompressed_proof); + } + + decompressed_proofs +} + +#[cfg(test)] +mod tests { + use rand::{thread_rng, Rng}; + + use super::*; + use crate::field::crandall_field::CrandallField; + use crate::field::field_types::Field; + use crate::hash::merkle_proofs::MerkleProof; + use crate::hash::merkle_tree::MerkleTree; + + #[test] + fn test_path_compression() { + type F = CrandallField; + let h = 10; + let cap_height = 3; + let vs = (0..1 << h).map(|_| vec![F::rand()]).collect::>(); + let mt = MerkleTree::new(vs.clone(), cap_height); + + let mut rng = thread_rng(); + let k = rng.gen_range(1..=1 << h); + let indices = (0..k).map(|_| rng.gen_range(0..1 << h)).collect::>(); + let proofs = indices.iter().map(|&i| mt.prove(i)).collect::>(); + + let compressed_proofs = compress_merkle_proofs(cap_height, &indices, &proofs); + let decompressed_proofs = decompress_merkle_proofs( + &indices.iter().map(|&i| vs[i].clone()).collect::>(), + &indices, + &compressed_proofs, + h, + cap_height, + ); + + assert_eq!(proofs, decompressed_proofs); + + let compressed_proof_bytes = serde_cbor::to_vec(&compressed_proofs).unwrap(); + println!( + "Compressed proof length: {} bytes", + compressed_proof_bytes.len() + ); + let proof_bytes = serde_cbor::to_vec(&proofs).unwrap(); + println!("Proof length: {} bytes", proof_bytes.len()); + } +} diff --git a/src/plonk/proof.rs b/src/plonk/proof.rs index 68871202..cdaa28ed 100644 --- a/src/plonk/proof.rs +++ b/src/plonk/proof.rs @@ -11,7 +11,7 @@ use crate::hash::merkle_tree::MerkleCap; use crate::iop::target::Target; use 
crate::plonk::circuit_data::CommonCircuitData; -#[derive(Serialize, Deserialize, Clone, Debug)] +#[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq)] #[serde(bound = "")] pub struct Proof, const D: usize> { /// Merkle cap of LDEs of wire values. @@ -26,13 +26,6 @@ pub struct Proof, const D: usize> { pub opening_proof: FriProof, } -#[derive(Serialize, Deserialize, Clone, Debug)] -#[serde(bound = "")] -pub struct ProofWithPublicInputs, const D: usize> { - pub proof: Proof, - pub public_inputs: Vec, -} - pub struct ProofTarget { pub wires_cap: MerkleCapTarget, pub plonk_zs_partial_products_cap: MerkleCapTarget, @@ -41,12 +34,57 @@ pub struct ProofTarget { pub opening_proof: FriProofTarget, } +impl, const D: usize> Proof { + /// Returns `true` iff the opening proof is compressed. + pub fn is_compressed(&self) -> bool { + self.opening_proof.is_compressed + } + + /// Compress the opening proof. + pub fn compress(mut self, common_data: &CommonCircuitData) -> Self { + self.opening_proof = self.opening_proof.compress(common_data); + self + } + + /// Decompress the opening proof. + pub fn decompress(mut self, common_data: &CommonCircuitData) -> Self { + self.opening_proof = self.opening_proof.decompress(common_data); + self + } +} + +#[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq)] +#[serde(bound = "")] +pub struct ProofWithPublicInputs, const D: usize> { + pub proof: Proof, + pub public_inputs: Vec, +} + pub struct ProofWithPublicInputsTarget { pub proof: ProofTarget, pub public_inputs: Vec, } -#[derive(Clone, Debug, Serialize, Deserialize)] +impl, const D: usize> ProofWithPublicInputs { + /// Returns `true` iff the opening proof is compressed. + pub fn is_compressed(&self) -> bool { + self.proof.is_compressed() + } + + /// Compress the opening proof. + pub fn compress(mut self, common_data: &CommonCircuitData) -> Self { + self.proof = self.proof.compress(common_data); + self + } + + /// Decompress the opening proof. 
+ pub fn decompress(mut self, common_data: &CommonCircuitData) -> Self { + self.proof = self.proof.decompress(common_data); + self + } +} + +#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)] /// The purported values of each polynomial at a single point. pub struct OpeningSet, const D: usize> { pub constants: Vec, @@ -102,3 +140,49 @@ pub struct OpeningSetTarget { pub partial_products: Vec>, pub quotient_polys: Vec>, } + +#[cfg(test)] +mod tests { + use anyhow::Result; + + use crate::field::crandall_field::CrandallField; + use crate::field::extension_field::algebra::ExtensionAlgebra; + use crate::field::extension_field::quartic::QuarticExtension; + use crate::field::field_types::Field; + use crate::iop::witness::PartialWitness; + use crate::plonk::circuit_builder::CircuitBuilder; + use crate::plonk::circuit_data::CircuitConfig; + use crate::plonk::verifier::verify; + + #[test] + fn test_proof_compression() -> Result<()> { + type F = CrandallField; + type FF = QuarticExtension; + const D: usize = 4; + + let config = CircuitConfig::large_config(); + + let pw = PartialWitness::new(); + let mut builder = CircuitBuilder::::new(config); + + // Build dummy circuit to get a valid proof. + let x = F::rand(); + let y = F::rand(); + let z = x * y; + let xt = builder.constant(x); + let yt = builder.constant(y); + let zt = builder.constant(z); + let comp_zt = builder.mul(xt, yt); + builder.connect(zt, comp_zt); + let data = builder.build(); + let proof = data.prove(pw)?; + + // Verify that `decompress ∘ compress = identity`. 
+ let compressed_proof = proof.clone().compress(&data.common); + let decompressed_compressed_proof = compressed_proof.clone().decompress(&data.common); + assert_eq!(proof, decompressed_compressed_proof); + + verify(proof, &data.verifier_only, &data.common)?; + verify(compressed_proof, &data.verifier_only, &data.common) + } +} diff --git a/src/plonk/recursive_verifier.rs b/src/plonk/recursive_verifier.rs index 0b371b9d..98d5d7cc 100644 --- a/src/plonk/recursive_verifier.rs +++ b/src/plonk/recursive_verifier.rs @@ -479,8 +479,16 @@ mod tests { builder.print_gate_counts(0); let data = builder.build(); let recursive_proof = data.prove(pw)?; + let now = std::time::Instant::now(); + let compressed_recursive_proof = recursive_proof.clone().compress(&data.common); + info!("{:.4} to compress proof", now.elapsed().as_secs_f64()); let proof_bytes = serde_cbor::to_vec(&recursive_proof).unwrap(); info!("Proof length: {} bytes", proof_bytes.len()); + let compressed_proof_bytes = serde_cbor::to_vec(&compressed_recursive_proof).unwrap(); + info!( + "Compressed proof length: {} bytes", + compressed_proof_bytes.len() + ); verify(recursive_proof, &data.verifier_only, &data.common) } } diff --git a/src/plonk/verifier.rs b/src/plonk/verifier.rs index 96c43d52..217e5cb4 100644 --- a/src/plonk/verifier.rs +++ b/src/plonk/verifier.rs @@ -12,10 +12,14 @@ use crate::plonk::vanishing_poly::eval_vanishing_poly; use crate::plonk::vars::EvaluationVars; pub(crate) fn verify, const D: usize>( - proof_with_pis: ProofWithPublicInputs, + mut proof_with_pis: ProofWithPublicInputs, verifier_data: &VerifierOnlyCircuitData, common_data: &CommonCircuitData, ) -> Result<()> { + // Decompress the proof if needed. + if proof_with_pis.is_compressed() { + proof_with_pis = proof_with_pis.decompress(common_data); + } let ProofWithPublicInputs { proof, public_inputs,