diff --git a/src/fri/proof.rs b/src/fri/proof.rs index 44cef3dc..b4ba8efb 100644 --- a/src/fri/proof.rs +++ b/src/fri/proof.rs @@ -1,4 +1,4 @@ -use std::collections::HashMap; +use std::collections::{HashMap, HashSet}; use itertools::izip; use serde::{Deserialize, Serialize}; @@ -85,7 +85,7 @@ pub struct FriQueryRoundTarget { pub struct CompressedFriQueryRounds, const D: usize> { /// Map from initial indices `i` to the `FriInitialProof` for the `i`th leaf. pub initial_trees_proofs: HashMap>, - /// For each FRI query step, a map from initial indices `i` to the `FriInitialProof` for the `i`th leaf. + /// For each FRI query step, a map from indices `i` to the `FriQueryStep` for the `i`th leaf. pub steps: Vec>>, } @@ -200,14 +200,17 @@ impl, const D: usize> FriProof { }; compressed_query_proofs .initial_trees_proofs - .insert(index, initial_proof); + .entry(index) + .or_insert(initial_proof); for j in 0..num_reductions { index >>= reduction_arity_bits[j]; let query_step = FriQueryStep { evals: steps_evals[j][i].clone(), merkle_proof: steps_proofs[j][i].clone(), }; - compressed_query_proofs.steps[j].insert(index, query_step); + compressed_query_proofs.steps[j] + .entry(index) + .or_insert(query_step); } } @@ -261,8 +264,17 @@ impl, const D: usize> CompressedFriProof { }) .collect::>(); + let mut seen_indices = vec![HashSet::new(); num_reductions + 1]; for mut index in indices.iter().copied() { - let initial_trees_proof = query_round_proofs.initial_trees_proofs[&index].clone(); + let mut initial_trees_proof = query_round_proofs.initial_trees_proofs[&index].clone(); + if !seen_indices[0].insert(index) { + initial_trees_proof + .evals_proofs + .iter_mut() + .for_each(|(_, p)| { + p.siblings = vec![]; + }); + } for (i, (leaves_data, proof)) in initial_trees_proof.evals_proofs.into_iter().enumerate() { @@ -272,7 +284,10 @@ impl, const D: usize> CompressedFriProof { } for i in 0..num_reductions { index >>= reduction_arity_bits[i]; - let query_step = query_round_proofs.steps[i][&index].clone(); + let mut query_step = query_round_proofs.steps[i][&index].clone(); + if !seen_indices[1 + i].insert(index) { + query_step.merkle_proof.siblings = vec![]; + } steps_indices[i].push(index); steps_evals[i].push(flatten(&query_step.evals)); steps_proofs[i].push(query_step.merkle_proof); @@ -292,7 +307,7 @@ impl, const D: usize> CompressedFriProof { .collect::>(); let mut decompressed_query_proofs = Vec::with_capacity(num_reductions); - for i in 0..num_reductions { + for i in 0..indices.len() { let initial_trees_proof = FriInitialTreeProof { evals_proofs: (0..num_initial_trees) .map(|j| { diff --git a/src/plonk/proof.rs b/src/plonk/proof.rs index c7d0b0ff..9510687a 100644 --- a/src/plonk/proof.rs +++ b/src/plonk/proof.rs @@ -252,7 +252,8 @@ mod tests { type F = CrandallField; const D: usize = 4; - let config = CircuitConfig::large_config(); + let mut config = CircuitConfig::large_config(); + config.fri_config.num_query_rounds = 50; let pw = PartialWitness::new(); let mut builder = CircuitBuilder::::new(config); diff --git a/src/plonk/recursive_verifier.rs b/src/plonk/recursive_verifier.rs index 75158d67..15dd7830 100644 --- a/src/plonk/recursive_verifier.rs +++ b/src/plonk/recursive_verifier.rs @@ -482,6 +482,10 @@ mod tests { let recursive_proof = data.prove(pw)?; let now = std::time::Instant::now(); let compressed_recursive_proof = recursive_proof.clone().compress(&data.common)?; + let decompressed_compressed_proof = compressed_recursive_proof + .clone() + .decompress(&data.common)?; + assert_eq!(recursive_proof, decompressed_compressed_proof); info!("{:.4} to compress proof", now.elapsed().as_secs_f64()); let proof_bytes = serde_cbor::to_vec(&recursive_proof).unwrap(); info!("Proof length: {} bytes", proof_bytes.len());