From 4aaf57e9a992844842dd157d8852ff495d27dd3d Mon Sep 17 00:00:00 2001 From: "Brandon H. Gomes" Date: Wed, 2 Nov 2022 21:07:51 -0700 Subject: [PATCH] feat: separate reading and writing to get infallible writers Signed-off-by: Brandon H. Gomes --- plonky2/Cargo.toml | 20 +- plonky2/src/fri/proof.rs | 2 +- plonky2/src/gates/gate.rs | 2 +- plonky2/src/hash/keccak.rs | 7 +- plonky2/src/hash/path_compression.rs | 2 +- plonky2/src/iop/witness.rs | 2 +- plonky2/src/plonk/circuit_builder.rs | 5 +- plonky2/src/plonk/get_challenges.rs | 2 +- plonky2/src/plonk/permutation_argument.rs | 2 +- plonky2/src/plonk/proof.rs | 22 +- plonky2/src/util/context_tree.rs | 2 +- plonky2/src/util/mod.rs | 5 +- plonky2/src/util/serialization.rs | 897 ++++++++++++++-------- util/src/lib.rs | 10 + 14 files changed, 612 insertions(+), 368 deletions(-) diff --git a/plonky2/Cargo.toml b/plonky2/Cargo.toml index 6de787d2..b74cc151 100644 --- a/plonky2/Cargo.toml +++ b/plonky2/Cargo.toml @@ -11,27 +11,17 @@ edition = "2021" default-run = "generate_constants" [features] -default = [ - "gate_testing", - "parallel", - "rand", - "rand_chacha", - "std", - "timing", -] -rand = [ - "dep:rand", - "num/rand", - "plonky2_field/rand" -] +default = ["gate_testing", "parallel", "rand", "rand_chacha", "std", "timing"] +rand = ["dep:rand", "num/rand", "plonky2_field/rand"] gate_testing = ["rand"] -parallel = ["maybe_rayon/parallel"] -std = ["anyhow/std"] +parallel = ["hashbrown/rayon", "maybe_rayon/parallel"] +std = ["anyhow/std", "rand/std"] timing = [] [dependencies] anyhow = { version = "1.0.40", default-features = false } derivative = { version = "2.2.0", default-features = false, features = ["use_core"] } +hashbrown = { version = "0.12.3", default-features = false, features = ["ahash", "serde"] } itertools = { version = "0.10.0", default-features = false } keccak-hash = { version = "0.8.0", default-features = false } log = { version = "0.4.14", default-features = false } diff --git a/plonky2/src/fri/proof.rs b/plonky2/src/fri/proof.rs index 2404a7bd..c3d3ecbb 100644 --- a/plonky2/src/fri/proof.rs +++ b/plonky2/src/fri/proof.rs @@ -1,7 +1,7 @@ use alloc::vec; use alloc::vec::Vec; -use std::collections::HashMap; +use hashbrown::HashMap; use itertools::izip; use plonky2_field::extension::{flatten, unflatten, Extendable}; use plonky2_field::polynomial::PolynomialCoeffs; diff --git a/plonky2/src/gates/gate.rs b/plonky2/src/gates/gate.rs index 648bf779..7b2db413 100644 --- a/plonky2/src/gates/gate.rs +++ b/plonky2/src/gates/gate.rs @@ -6,8 +6,8 @@ use alloc::vec::Vec; use core::fmt::{Debug, Error, Formatter}; use core::hash::{Hash, Hasher}; use core::ops::Range; -use std::collections::HashMap; +use hashbrown::HashMap; use plonky2_field::batch_util::batch_multiply_inplace; use plonky2_field::extension::{Extendable, FieldExtension}; use plonky2_field::types::Field; diff --git a/plonky2/src/hash/keccak.rs b/plonky2/src/hash/keccak.rs index 1a4f5472..0efa154c 100644 --- a/plonky2/src/hash/keccak.rs +++ b/plonky2/src/hash/keccak.rs @@ -9,7 +9,7 @@ use keccak_hash::keccak; use crate::hash::hash_types::{BytesHash, RichField}; use crate::hash::hashing::{PlonkyPermutation, SPONGE_WIDTH}; use crate::plonk::config::Hasher; -use crate::util::serialization::Buffer; +use crate::util::serialization::Write; /// Keccak-256 pseudo-permutation (not necessarily one-to-one) used in the challenger. /// A state `input: [F; 12]` is sent to the field representation of `H(input) || H(H(input)) || H(H(H(input)))` @@ -53,16 +53,17 @@ impl PlonkyPermutation for KeccakPermutation { /// Keccak-256 hash function. #[derive(Copy, Clone, Debug, Eq, PartialEq)] pub struct KeccakHash; + impl Hasher for KeccakHash { const HASH_SIZE: usize = N; type Hash = BytesHash; type Permutation = KeccakPermutation; fn hash_no_pad(input: &[F]) -> Self::Hash { - let mut buffer = Buffer::new(Vec::new()); + let mut buffer = Vec::new(); buffer.write_field_vec(input).unwrap(); let mut arr = [0; N]; - let hash_bytes = keccak(buffer.bytes()).0; + let hash_bytes = keccak(buffer).0; arr.copy_from_slice(&hash_bytes[..N]); BytesHash(arr) } diff --git a/plonky2/src/hash/path_compression.rs b/plonky2/src/hash/path_compression.rs index ed3f49b8..ed93b25d 100644 --- a/plonky2/src/hash/path_compression.rs +++ b/plonky2/src/hash/path_compression.rs @@ -1,7 +1,7 @@ use alloc::vec; use alloc::vec::Vec; -use std::collections::HashMap; +use hashbrown::HashMap; use num::Integer; use crate::hash::hash_types::RichField; diff --git a/plonky2/src/iop/witness.rs b/plonky2/src/iop/witness.rs index ef0ff7ae..a7cdf1f4 100644 --- a/plonky2/src/iop/witness.rs +++ b/plonky2/src/iop/witness.rs @@ -1,7 +1,7 @@ use alloc::vec; use alloc::vec::Vec; -use std::collections::HashMap; +use hashbrown::HashMap; use itertools::Itertools; use plonky2_field::extension::{Extendable, FieldExtension}; use plonky2_field::types::Field; diff --git a/plonky2/src/plonk/circuit_builder.rs b/plonky2/src/plonk/circuit_builder.rs index 06b11c05..deb46442 100644 --- a/plonky2/src/plonk/circuit_builder.rs +++ b/plonky2/src/plonk/circuit_builder.rs @@ -3,9 +3,10 @@ use alloc::collections::BTreeMap; use alloc::vec; use alloc::vec::Vec; use core::cmp::max; -use std::collections::{HashMap, HashSet}; +#[cfg(feature = "std")] use std::time::Instant; +use hashbrown::{HashMap, HashSet}; use itertools::Itertools; use log::{debug, info, Level}; use plonky2_field::cosets::get_unique_coset_shifts; @@ -692,6 +693,7 @@ impl, const D: usize> CircuitBuilder { /// Builds a "full circuit", with both prover and verifier data. pub fn build>(mut self) -> CircuitData { let mut timing = TimingTree::new("preprocess", Level::Trace); + #[cfg(feature = "std")] let start = Instant::now(); let rate_bits = self.config.fri_config.rate_bits; let cap_height = self.config.fri_config.cap_height; @@ -879,6 +881,7 @@ impl, const D: usize> CircuitBuilder { }; timing.print(); + #[cfg(feature = "std")] debug!("Building circuit took {}s", start.elapsed().as_secs_f32()); CircuitData { prover_only, diff --git a/plonky2/src/plonk/get_challenges.rs b/plonky2/src/plonk/get_challenges.rs index f0635d93..4c4785f8 100644 --- a/plonky2/src/plonk/get_challenges.rs +++ b/plonky2/src/plonk/get_challenges.rs @@ -1,7 +1,7 @@ use alloc::vec; use alloc::vec::Vec; -use std::collections::HashSet; +use hashbrown::HashSet; use plonky2_field::extension::Extendable; use plonky2_field::polynomial::PolynomialCoeffs; diff --git a/plonky2/src/plonk/permutation_argument.rs b/plonky2/src/plonk/permutation_argument.rs index d9d56f93..2400516a 100644 --- a/plonky2/src/plonk/permutation_argument.rs +++ b/plonky2/src/plonk/permutation_argument.rs @@ -1,6 +1,6 @@ use alloc::vec::Vec; -use std::collections::HashMap; +use hashbrown::HashMap; use maybe_rayon::*; use plonky2_field::polynomial::PolynomialValues; use plonky2_field::types::Field; diff --git a/plonky2/src/plonk/proof.rs b/plonky2/src/plonk/proof.rs index a2a60d8c..a8a2f418 100644 --- a/plonky2/src/plonk/proof.rs +++ b/plonky2/src/plonk/proof.rs @@ -21,7 +21,9 @@ use crate::iop::target::Target; use crate::plonk::circuit_data::{CommonCircuitData, VerifierOnlyCircuitData}; use crate::plonk::config::{GenericConfig, Hasher}; use crate::plonk::verifier::verify_with_challenges; -use crate::util::serialization::Buffer; +use crate::util::serialization::Write; +#[cfg(feature = "std")] +use crate::util::serialization::{Buffer, Read}; #[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq)] #[serde(bound = "")] @@ -101,12 +103,13 @@ impl, C: GenericConfig, const D: usize> C::InnerHasher::hash_no_pad(&self.public_inputs) } - pub fn to_bytes(&self) -> anyhow::Result> { - let mut buffer = Buffer::new(Vec::new()); - buffer.write_proof_with_public_inputs(self)?; - Ok(buffer.bytes()) + pub fn to_bytes(&self) -> Vec { + let mut buffer = Vec::new(); + let _ = buffer.write_proof_with_public_inputs(self); + buffer } + #[cfg(feature = "std")] pub fn from_bytes( bytes: Vec, common_data: &CommonCircuitData, @@ -226,11 +229,10 @@ impl, C: GenericConfig, const D: usize> C::InnerHasher::hash_no_pad(&self.public_inputs) } - #[cfg(feature = "std")] - pub fn to_bytes(&self) -> anyhow::Result> { - let mut buffer = Buffer::new(Vec::new()); - buffer.write_compressed_proof_with_public_inputs(self)?; - Ok(buffer.bytes()) + pub fn to_bytes(&self) -> Vec { + let mut buffer = Vec::new(); + let _ = buffer.write_compressed_proof_with_public_inputs(self); + buffer } #[cfg(feature = "std")] diff --git a/plonky2/src/util/context_tree.rs b/plonky2/src/util/context_tree.rs index bc37b2b0..565e2d35 100644 --- a/plonky2/src/util/context_tree.rs +++ b/plonky2/src/util/context_tree.rs @@ -1,4 +1,4 @@ -use alloc::string::String; +use alloc::string::{String, ToString}; use alloc::vec; use alloc::vec::Vec; diff --git a/plonky2/src/util/mod.rs b/plonky2/src/util/mod.rs index d978dbc5..2c61b399 100644 --- a/plonky2/src/util/mod.rs +++ b/plonky2/src/util/mod.rs @@ -6,13 +6,12 @@ use plonky2_field::types::Field; pub(crate) mod context_tree; pub(crate) mod partial_products; + pub mod reducing; +pub mod serialization; pub mod strided_view; pub mod timing; -#[cfg(feature = "std")] -pub mod serialization; - pub(crate) fn transpose_poly_values(polys: Vec>) -> Vec> { let poly_values = polys.into_iter().map(|p| p.values).collect::>(); transpose(&poly_values) diff --git a/plonky2/src/util/serialization.rs b/plonky2/src/util/serialization.rs index de8961ae..b5c7ef5c 100644 --- a/plonky2/src/util/serialization.rs +++ b/plonky2/src/util/serialization.rs @@ -1,6 +1,10 @@ -use std::collections::HashMap; -use std::io::{Cursor, Read, Result, Write}; +use alloc::vec; +use alloc::vec::Vec; +use core::convert::Infallible; +#[cfg(feature = "std")] +use std::io::{self, Cursor, Read as _, Write as _}; +use hashbrown::HashMap; use plonky2_field::extension::{Extendable, FieldExtension}; use plonky2_field::polynomial::PolynomialCoeffs; use plonky2_field::types::{Field64, PrimeField64}; @@ -19,65 +23,72 @@ use crate::plonk::proof::{ CompressedProof, CompressedProofWithPublicInputs, OpeningSet, Proof, ProofWithPublicInputs, }; -#[derive(Debug)] -pub struct Buffer(Cursor>); +/// Buffer Position +pub trait Position { + /// Returns the position of the buffer. + fn position(&self) -> u64; +} -impl Buffer { - pub fn new(buffer: Vec) -> Self { - Self(Cursor::new(buffer)) - } +/// Buffer Size +pub trait Size { + /// Returns the length of `self`. + fn len(&self) -> usize; - pub fn len(&self) -> usize { - self.0.get_ref().len() - } - - pub fn is_empty(&self) -> bool { + /// Returns `true` if `self` has length zero. + #[inline] + fn is_empty(&self) -> bool { self.len() == 0 } +} - pub fn bytes(self) -> Vec { - self.0.into_inner() +impl Size for Vec { + #[inline] + fn len(&self) -> usize { + self.len() } +} - fn write_u8(&mut self, x: u8) -> Result<()> { - self.0.write_all(&[x]) - } - fn read_u8(&mut self) -> Result { - let mut buf = [0; std::mem::size_of::()]; - self.0.read_exact(&mut buf)?; +/// +pub trait Read { + /// + type Error; + + /// + fn read_exact(&mut self, bytes: &mut [u8]) -> Result<(), Self::Error>; + + /// + #[inline] + fn read_u8(&mut self) -> Result { + let mut buf = [0; core::mem::size_of::()]; + self.read_exact(&mut buf)?; Ok(buf[0]) } - fn write_u32(&mut self, x: u32) -> Result<()> { - self.0.write_all(&x.to_le_bytes()) - } - fn read_u32(&mut self) -> Result { - let mut buf = [0; std::mem::size_of::()]; - self.0.read_exact(&mut buf)?; + /// + #[inline] + fn read_u32(&mut self) -> Result { + let mut buf = [0; core::mem::size_of::()]; + self.read_exact(&mut buf)?; Ok(u32::from_le_bytes(buf)) } - fn write_field(&mut self, x: F) -> Result<()> { - self.0.write_all(&x.to_canonical_u64().to_le_bytes()) - } - fn read_field(&mut self) -> Result { - let mut buf = [0; std::mem::size_of::()]; - self.0.read_exact(&mut buf)?; + /// + #[inline] + fn read_field(&mut self) -> Result + where + F: Field64, + { + let mut buf = [0; core::mem::size_of::()]; + self.read_exact(&mut buf)?; Ok(F::from_canonical_u64(u64::from_le_bytes(buf))) } - fn write_field_ext, const D: usize>( - &mut self, - x: F::Extension, - ) -> Result<()> { - for &a in &x.to_basefield_array() { - self.write_field(a)?; - } - Ok(()) - } - fn read_field_ext, const D: usize>( - &mut self, - ) -> Result { + /// + #[inline] + fn read_field_ext(&mut self) -> Result + where + F: RichField + Extendable, + { let mut arr = [F::ZERO; D]; for a in arr.iter_mut() { *a = self.read_field()?; @@ -87,87 +98,66 @@ impl Buffer { )) } - fn write_hash>(&mut self, h: H::Hash) -> Result<()> { - self.0.write_all(&h.to_bytes()) - } - - fn read_hash>(&mut self) -> Result { + /// + #[inline] + fn read_hash(&mut self) -> Result + where + F: RichField, + H: Hasher, + { let mut buf = vec![0; H::HASH_SIZE]; - self.0.read_exact(&mut buf)?; + self.read_exact(&mut buf)?; Ok(H::Hash::from_bytes(&buf)) } - fn write_merkle_cap>( - &mut self, - cap: &MerkleCap, - ) -> Result<()> { - for &a in &cap.0 { - self.write_hash::(a)?; - } - Ok(()) - } - fn read_merkle_cap>( - &mut self, - cap_height: usize, - ) -> Result> { + /// + #[inline] + fn read_merkle_cap(&mut self, cap_height: usize) -> Result, Self::Error> + where + F: RichField, + H: Hasher, + { let cap_length = 1 << cap_height; Ok(MerkleCap( (0..cap_length) .map(|_| self.read_hash::()) - .collect::>>()?, + .collect::, _>>()?, )) } - pub fn write_field_vec(&mut self, v: &[F]) -> Result<()> { - for &a in v { - self.write_field(a)?; - } - Ok(()) - } - pub fn read_field_vec(&mut self, length: usize) -> Result> { + /// + #[inline] + fn read_field_vec(&mut self, length: usize) -> Result, Self::Error> + where + F: Field64, + { (0..length) .map(|_| self.read_field()) - .collect::>>() + .collect::, _>>() } - fn write_field_ext_vec, const D: usize>( - &mut self, - v: &[F::Extension], - ) -> Result<()> { - for &a in v { - self.write_field_ext::(a)?; - } - Ok(()) - } - fn read_field_ext_vec, const D: usize>( + /// + #[inline] + fn read_field_ext_vec( &mut self, length: usize, - ) -> Result> { - (0..length) - .map(|_| self.read_field_ext::()) - .collect::>>() + ) -> Result, Self::Error> + where + F: RichField + Extendable, + { + (0..length).map(|_| self.read_field_ext::()).collect() } - fn write_opening_set, const D: usize>( - &mut self, - os: &OpeningSet, - ) -> Result<()> { - self.write_field_ext_vec::(&os.constants)?; - self.write_field_ext_vec::(&os.plonk_sigmas)?; - self.write_field_ext_vec::(&os.wires)?; - self.write_field_ext_vec::(&os.plonk_zs)?; - self.write_field_ext_vec::(&os.plonk_zs_next)?; - self.write_field_ext_vec::(&os.partial_products)?; - self.write_field_ext_vec::(&os.quotient_polys) - } - fn read_opening_set< - F: RichField + Extendable, - C: GenericConfig, - const D: usize, - >( + /// + #[inline] + fn read_opening_set( &mut self, common_data: &CommonCircuitData, - ) -> Result> { + ) -> Result, Self::Error> + where + F: RichField + Extendable, + C: GenericConfig, + { let config = &common_data.config; let constants = self.read_field_ext_vec::(common_data.num_constants)?; let plonk_sigmas = self.read_field_ext_vec::(config.num_routed_wires)?; @@ -190,52 +180,31 @@ impl Buffer { }) } - fn write_merkle_proof>( - &mut self, - p: &MerkleProof, - ) -> Result<()> { - let length = p.siblings.len(); - self.write_u8( - length - .try_into() - .expect("Merkle proof length must fit in u8."), - )?; - for &h in &p.siblings { - self.write_hash::(h)?; - } - Ok(()) - } - fn read_merkle_proof>(&mut self) -> Result> { + /// + #[inline] + fn read_merkle_proof(&mut self) -> Result, Self::Error> + where + F: RichField, + H: Hasher, + { let length = self.read_u8()?; Ok(MerkleProof { siblings: (0..length) .map(|_| self.read_hash::()) - .collect::>>()?, + .collect::, _>>()?, }) } - fn write_fri_initial_proof< - F: RichField + Extendable, - C: GenericConfig, - const D: usize, - >( - &mut self, - fitp: &FriInitialTreeProof, - ) -> Result<()> { - for (v, p) in &fitp.evals_proofs { - self.write_field_vec(v)?; - self.write_merkle_proof(p)?; - } - Ok(()) - } - fn read_fri_initial_proof< - F: RichField + Extendable, - C: GenericConfig, - const D: usize, - >( + /// + #[inline] + fn read_fri_initial_proof( &mut self, common_data: &CommonCircuitData, - ) -> Result> { + ) -> Result, Self::Error> + where + F: RichField + Extendable, + C: GenericConfig, + { let config = &common_data.config; let salt = salt_size(common_data.fri_params.hiding); let mut evals_proofs = Vec::with_capacity(4); @@ -263,26 +232,17 @@ impl Buffer { Ok(FriInitialTreeProof { evals_proofs }) } - fn write_fri_query_step< - F: RichField + Extendable, - C: GenericConfig, - const D: usize, - >( - &mut self, - fqs: &FriQueryStep, - ) -> Result<()> { - self.write_field_ext_vec::(&fqs.evals)?; - self.write_merkle_proof(&fqs.merkle_proof) - } - fn read_fri_query_step< - F: RichField + Extendable, - C: GenericConfig, - const D: usize, - >( + /// + #[inline] + fn read_fri_query_step( &mut self, arity: usize, compressed: bool, - ) -> Result> { + ) -> Result, Self::Error> + where + F: RichField + Extendable, + C: GenericConfig, + { let evals = self.read_field_ext_vec::(arity - usize::from(compressed))?; let merkle_proof = self.read_merkle_proof()?; Ok(FriQueryStep { @@ -291,30 +251,16 @@ impl Buffer { }) } - fn write_fri_query_rounds< - F: RichField + Extendable, - C: GenericConfig, - const D: usize, - >( - &mut self, - fqrs: &[FriQueryRound], - ) -> Result<()> { - for fqr in fqrs { - self.write_fri_initial_proof::(&fqr.initial_trees_proof)?; - for fqs in &fqr.steps { - self.write_fri_query_step::(fqs)?; - } - } - Ok(()) - } - fn read_fri_query_rounds< - F: RichField + Extendable, - C: GenericConfig, - const D: usize, - >( + /// + #[inline] + fn read_fri_query_rounds( &mut self, common_data: &CommonCircuitData, - ) -> Result>> { + ) -> Result>, Self::Error> + where + F: RichField + Extendable, + C: GenericConfig, + { let config = &common_data.config; let mut fqrs = Vec::with_capacity(config.fri_config.num_query_rounds); for _ in 0..config.fri_config.num_query_rounds { @@ -324,7 +270,7 @@ impl Buffer { .reduction_arity_bits .iter() .map(|&ar| self.read_fri_query_step::(1 << ar, false)) - .collect::>()?; + .collect::>()?; fqrs.push(FriQueryRound { initial_trees_proof, steps, @@ -333,25 +279,20 @@ impl Buffer { Ok(fqrs) } - fn write_fri_proof, C: GenericConfig, const D: usize>( - &mut self, - fp: &FriProof, - ) -> Result<()> { - for cap in &fp.commit_phase_merkle_caps { - self.write_merkle_cap(cap)?; - } - self.write_fri_query_rounds::(&fp.query_round_proofs)?; - self.write_field_ext_vec::(&fp.final_poly.coeffs)?; - self.write_field(fp.pow_witness) - } - fn read_fri_proof, C: GenericConfig, const D: usize>( + /// + #[inline] + fn read_fri_proof( &mut self, common_data: &CommonCircuitData, - ) -> Result> { + ) -> Result, Self::Error> + where + F: RichField + Extendable, + C: GenericConfig, + { let config = &common_data.config; let commit_phase_merkle_caps = (0..common_data.fri_params.reduction_arity_bits.len()) .map(|_| self.read_merkle_cap(config.fri_config.cap_height)) - .collect::>>()?; + .collect::, _>>()?; let query_round_proofs = self.read_fri_query_rounds::(common_data)?; let final_poly = PolynomialCoeffs::new( self.read_field_ext_vec::(common_data.fri_params.final_poly_len())?, @@ -365,27 +306,22 @@ impl Buffer { }) } - pub fn write_proof, C: GenericConfig, const D: usize>( - &mut self, - proof: &Proof, - ) -> Result<()> { - self.write_merkle_cap(&proof.wires_cap)?; - self.write_merkle_cap(&proof.plonk_zs_partial_products_cap)?; - self.write_merkle_cap(&proof.quotient_polys_cap)?; - self.write_opening_set(&proof.openings)?; - self.write_fri_proof::(&proof.opening_proof) - } - pub fn read_proof, C: GenericConfig, const D: usize>( + /// + #[inline] + fn read_proof( &mut self, common_data: &CommonCircuitData, - ) -> Result> { + ) -> Result, Self::Error> + where + F: RichField + Extendable, + C: GenericConfig, + { let config = &common_data.config; let wires_cap = self.read_merkle_cap(config.fri_config.cap_height)?; let plonk_zs_partial_products_cap = self.read_merkle_cap(config.fri_config.cap_height)?; let quotient_polys_cap = self.read_merkle_cap(config.fri_config.cap_height)?; let openings = self.read_opening_set::(common_data)?; let opening_proof = self.read_fri_proof::(common_data)?; - Ok(Proof { wires_cap, plonk_zs_partial_products_cap, @@ -395,78 +331,41 @@ impl Buffer { }) } - pub fn write_proof_with_public_inputs< - F: RichField + Extendable, - C: GenericConfig, - const D: usize, - >( - &mut self, - proof_with_pis: &ProofWithPublicInputs, - ) -> Result<()> { - let ProofWithPublicInputs { - proof, - public_inputs, - } = proof_with_pis; - self.write_proof(proof)?; - self.write_field_vec(public_inputs) - } - pub fn read_proof_with_public_inputs< - F: RichField + Extendable, - C: GenericConfig, - const D: usize, - >( + /// + #[inline] + fn read_proof_with_public_inputs( &mut self, common_data: &CommonCircuitData, - ) -> Result> { + ) -> Result, Self::Error> + where + Self: Position + Size, + F: RichField + Extendable, + C: GenericConfig, + { let proof = self.read_proof(common_data)?; let public_inputs = self.read_field_vec( - (self.len() - self.0.position() as usize) / std::mem::size_of::(), + (self.len() - self.position() as usize) / core::mem::size_of::(), )?; - Ok(ProofWithPublicInputs { proof, public_inputs, }) } - fn write_compressed_fri_query_rounds< - F: RichField + Extendable, - C: GenericConfig, - const D: usize, - >( - &mut self, - cfqrs: &CompressedFriQueryRounds, - ) -> Result<()> { - for &i in &cfqrs.indices { - self.write_u32(i as u32)?; - } - - let mut initial_trees_proofs = cfqrs.initial_trees_proofs.iter().collect::>(); - initial_trees_proofs.sort_by_key(|&x| x.0); - for (_, itp) in initial_trees_proofs { - self.write_fri_initial_proof::(itp)?; - } - for h in &cfqrs.steps { - let mut fri_query_steps = h.iter().collect::>(); - fri_query_steps.sort_by_key(|&x| x.0); - for (_, fqs) in fri_query_steps { - self.write_fri_query_step::(fqs)?; - } - } - Ok(()) - } - fn read_compressed_fri_query_rounds< - F: RichField + Extendable, - C: GenericConfig, - const D: usize, - >( + /// + #[inline] + fn read_compressed_fri_query_rounds( &mut self, common_data: &CommonCircuitData, - ) -> Result> { + ) -> Result, Self::Error> + where + F: RichField + Extendable, + C: GenericConfig, + { let config = &common_data.config; let original_indices = (0..config.fri_config.num_query_rounds) .map(|_| self.read_u32().map(|i| i as usize)) - .collect::>>()?; + .collect::, _>>()?; let mut indices = original_indices.clone(); indices.sort_unstable(); indices.dedup(); @@ -484,7 +383,7 @@ impl Buffer { indices.dedup(); let query_steps = (0..indices.len()) .map(|_| self.read_fri_query_step::(1 << a, true)) - .collect::>>()?; + .collect::, _>>()?; steps.push( indices .iter() @@ -501,33 +400,20 @@ impl Buffer { }) } - fn write_compressed_fri_proof< - F: RichField + Extendable, - C: GenericConfig, - const D: usize, - >( - &mut self, - fp: &CompressedFriProof, - ) -> Result<()> { - for cap in &fp.commit_phase_merkle_caps { - self.write_merkle_cap(cap)?; - } - self.write_compressed_fri_query_rounds::(&fp.query_round_proofs)?; - self.write_field_ext_vec::(&fp.final_poly.coeffs)?; - self.write_field(fp.pow_witness) - } - fn read_compressed_fri_proof< - F: RichField + Extendable, - C: GenericConfig, - const D: usize, - >( + /// + #[inline] + fn read_compressed_fri_proof( &mut self, common_data: &CommonCircuitData, - ) -> Result> { + ) -> Result, Self::Error> + where + F: RichField + Extendable, + C: GenericConfig, + { let config = &common_data.config; let commit_phase_merkle_caps = (0..common_data.fri_params.reduction_arity_bits.len()) .map(|_| self.read_merkle_cap(config.fri_config.cap_height)) - .collect::>>()?; + .collect::, _>>()?; let query_round_proofs = self.read_compressed_fri_query_rounds::(common_data)?; let final_poly = PolynomialCoeffs::new( self.read_field_ext_vec::(common_data.fri_params.final_poly_len())?, @@ -541,35 +427,22 @@ impl Buffer { }) } - pub fn write_compressed_proof< - F: RichField + Extendable, - C: GenericConfig, - const D: usize, - >( - &mut self, - proof: &CompressedProof, - ) -> Result<()> { - self.write_merkle_cap(&proof.wires_cap)?; - self.write_merkle_cap(&proof.plonk_zs_partial_products_cap)?; - self.write_merkle_cap(&proof.quotient_polys_cap)?; - self.write_opening_set(&proof.openings)?; - self.write_compressed_fri_proof::(&proof.opening_proof) - } - pub fn read_compressed_proof< - F: RichField + Extendable, - C: GenericConfig, - const D: usize, - >( + /// + #[inline] + fn read_compressed_proof( &mut self, common_data: &CommonCircuitData, - ) -> Result> { + ) -> Result, Self::Error> + where + F: RichField + Extendable, + C: GenericConfig, + { let config = &common_data.config; let wires_cap = self.read_merkle_cap(config.fri_config.cap_height)?; let plonk_zs_partial_products_cap = self.read_merkle_cap(config.fri_config.cap_height)?; let quotient_polys_cap = self.read_merkle_cap(config.fri_config.cap_height)?; let openings = self.read_opening_set::(common_data)?; let opening_proof = self.read_compressed_fri_proof::(common_data)?; - Ok(CompressedProof { wires_cap, plonk_zs_partial_products_cap, @@ -579,14 +452,332 @@ impl Buffer { }) } - pub fn write_compressed_proof_with_public_inputs< + /// + #[inline] + fn read_compressed_proof_with_public_inputs( + &mut self, + common_data: &CommonCircuitData, + ) -> Result, Self::Error> + where + Self: Position + Size, F: RichField + Extendable, C: GenericConfig, - const D: usize, - >( + { + let proof = self.read_compressed_proof(common_data)?; + let public_inputs = self.read_field_vec( + (self.len() - self.position() as usize) / core::mem::size_of::(), + )?; + Ok(CompressedProofWithPublicInputs { + proof, + public_inputs, + }) + } +} + +/// +pub trait Write { + /// + type Error; + + /// + fn write_all(&mut self, bytes: &[u8]) -> Result<(), Self::Error>; + + /// + #[inline] + fn write_u8(&mut self, x: u8) -> Result<(), Self::Error> { + self.write_all(&[x]) + } + + /// + #[inline] + fn write_u32(&mut self, x: u32) -> Result<(), Self::Error> { + self.write_all(&x.to_le_bytes()) + } + + /// + #[inline] + fn write_field(&mut self, x: F) -> Result<(), Self::Error> + where + F: PrimeField64, + { + self.write_all(&x.to_canonical_u64().to_le_bytes()) + } + + /// + #[inline] + fn write_field_ext(&mut self, x: F::Extension) -> Result<(), Self::Error> + where + F: RichField + Extendable, + { + for &a in &x.to_basefield_array() { + self.write_field(a)?; + } + Ok(()) + } + + /// + #[inline] + fn write_hash(&mut self, h: H::Hash) -> Result<(), Self::Error> + where + F: RichField, + H: Hasher, + { + self.write_all(&h.to_bytes()) + } + + /// + #[inline] + fn write_merkle_cap(&mut self, cap: &MerkleCap) -> Result<(), Self::Error> + where + F: RichField, + H: Hasher, + { + for &a in &cap.0 { + self.write_hash::(a)?; + } + Ok(()) + } + + /// + #[inline] + fn write_field_vec(&mut self, v: &[F]) -> Result<(), Self::Error> + where + F: PrimeField64, + { + for &a in v { + self.write_field(a)?; + } + Ok(()) + } + + /// + #[inline] + fn write_field_ext_vec( + &mut self, + v: &[F::Extension], + ) -> Result<(), Self::Error> + where + F: RichField + Extendable, + { + for &a in v { + self.write_field_ext::(a)?; + } + Ok(()) + } + + /// + #[inline] + fn write_opening_set( + &mut self, + os: &OpeningSet, + ) -> Result<(), Self::Error> + where + F: RichField + Extendable, + { + self.write_field_ext_vec::(&os.constants)?; + self.write_field_ext_vec::(&os.plonk_sigmas)?; + self.write_field_ext_vec::(&os.wires)?; + self.write_field_ext_vec::(&os.plonk_zs)?; + self.write_field_ext_vec::(&os.plonk_zs_next)?; + self.write_field_ext_vec::(&os.partial_products)?; + self.write_field_ext_vec::(&os.quotient_polys) + } + + /// + #[inline] + fn write_merkle_proof(&mut self, p: &MerkleProof) -> Result<(), Self::Error> + where + F: RichField, + H: Hasher, + { + let length = p.siblings.len(); + self.write_u8( + length + .try_into() + .expect("Merkle proof length must fit in u8."), + )?; + for &h in &p.siblings { + self.write_hash::(h)?; + } + Ok(()) + } + + /// + #[inline] + fn write_fri_initial_proof( + &mut self, + fitp: &FriInitialTreeProof, + ) -> Result<(), Self::Error> + where + F: RichField + Extendable, + C: GenericConfig, + { + for (v, p) in &fitp.evals_proofs { + self.write_field_vec(v)?; + self.write_merkle_proof(p)?; + } + Ok(()) + } + + /// + #[inline] + fn write_fri_query_step( + &mut self, + fqs: &FriQueryStep, + ) -> Result<(), Self::Error> + where + F: RichField + Extendable, + C: GenericConfig, + { + self.write_field_ext_vec::(&fqs.evals)?; + self.write_merkle_proof(&fqs.merkle_proof) + } + + /// + #[inline] + fn write_fri_query_rounds( + &mut self, + fqrs: &[FriQueryRound], + ) -> Result<(), Self::Error> + where + F: RichField + Extendable, + C: GenericConfig, + { + for fqr in fqrs { + self.write_fri_initial_proof::(&fqr.initial_trees_proof)?; + for fqs in &fqr.steps { + self.write_fri_query_step::(fqs)?; + } + } + Ok(()) + } + + /// + #[inline] + fn write_fri_proof( + &mut self, + fp: &FriProof, + ) -> Result<(), Self::Error> + where + F: RichField + Extendable, + C: GenericConfig, + { + for cap in &fp.commit_phase_merkle_caps { + self.write_merkle_cap(cap)?; + } + self.write_fri_query_rounds::(&fp.query_round_proofs)?; + self.write_field_ext_vec::(&fp.final_poly.coeffs)?; + self.write_field(fp.pow_witness) + } + + /// + #[inline] + fn write_proof( + &mut self, + proof: &Proof, + ) -> Result<(), Self::Error> + where + F: RichField + Extendable, + C: GenericConfig, + { + self.write_merkle_cap(&proof.wires_cap)?; + self.write_merkle_cap(&proof.plonk_zs_partial_products_cap)?; + self.write_merkle_cap(&proof.quotient_polys_cap)?; + self.write_opening_set(&proof.openings)?; + self.write_fri_proof::(&proof.opening_proof) + } + + /// + #[inline] + fn write_proof_with_public_inputs( + &mut self, + proof_with_pis: &ProofWithPublicInputs, + ) -> Result<(), Self::Error> + where + F: RichField + Extendable, + C: GenericConfig, + { + let ProofWithPublicInputs { + proof, + public_inputs, + } = proof_with_pis; + self.write_proof(proof)?; + self.write_field_vec(public_inputs) + } + + /// + #[inline] + fn write_compressed_fri_query_rounds( + &mut self, + cfqrs: &CompressedFriQueryRounds, + ) -> Result<(), Self::Error> + where + F: RichField + Extendable, + C: GenericConfig, + { + for &i in &cfqrs.indices { + self.write_u32(i as u32)?; + } + let mut initial_trees_proofs = cfqrs.initial_trees_proofs.iter().collect::>(); + initial_trees_proofs.sort_by_key(|&x| x.0); + for (_, itp) in initial_trees_proofs { + self.write_fri_initial_proof::(itp)?; + } + for h in &cfqrs.steps { + let mut fri_query_steps = h.iter().collect::>(); + fri_query_steps.sort_by_key(|&x| x.0); + for (_, fqs) in fri_query_steps { + self.write_fri_query_step::(fqs)?; + } + } + Ok(()) + } + + /// + #[inline] + fn write_compressed_fri_proof( + &mut self, + fp: &CompressedFriProof, + ) -> Result<(), Self::Error> + where + F: RichField + Extendable, + C: GenericConfig, + { + for cap in &fp.commit_phase_merkle_caps { + self.write_merkle_cap(cap)?; + } + self.write_compressed_fri_query_rounds::(&fp.query_round_proofs)?; + self.write_field_ext_vec::(&fp.final_poly.coeffs)?; + self.write_field(fp.pow_witness) + } + + /// + #[inline] + fn write_compressed_proof( + &mut self, + proof: &CompressedProof, + ) -> Result<(), Self::Error> + where + F: RichField + Extendable, + C: GenericConfig, + { + self.write_merkle_cap(&proof.wires_cap)?; + self.write_merkle_cap(&proof.plonk_zs_partial_products_cap)?; + self.write_merkle_cap(&proof.quotient_polys_cap)?; + self.write_opening_set(&proof.openings)?; + self.write_compressed_fri_proof::(&proof.opening_proof) + } + + /// + #[inline] + fn write_compressed_proof_with_public_inputs( &mut self, proof_with_pis: &CompressedProofWithPublicInputs, - ) -> Result<()> { + ) -> Result<(), Self::Error> + where + F: RichField + Extendable, + C: GenericConfig, + { let CompressedProofWithPublicInputs { proof, public_inputs, @@ -594,22 +785,70 @@ impl Buffer { self.write_compressed_proof(proof)?; self.write_field_vec(public_inputs) } - pub fn read_compressed_proof_with_public_inputs< - F: RichField + Extendable, - C: GenericConfig, - const D: usize, - >( - &mut self, - common_data: &CommonCircuitData, - ) -> Result> { - let proof = self.read_compressed_proof(common_data)?; - let public_inputs = self.read_field_vec( - (self.len() - self.0.position() as usize) / std::mem::size_of::(), - )?; +} - Ok(CompressedProofWithPublicInputs { - proof, - public_inputs, - }) +impl Write for Vec { + type Error = Infallible; + + #[inline] + fn write_all(&mut self, bytes: &[u8]) -> Result<(), Self::Error> { + self.extend_from_slice(bytes); + Ok(()) + } +} + +/// +#[cfg(feature = "std")] +#[derive(Debug)] +pub struct Buffer(Cursor>); + +#[cfg(feature = "std")] +impl Buffer { + /// + #[inline] + pub fn new(buffer: Vec) -> Self { + Self(Cursor::new(buffer)) + } + + /// + #[inline] + pub fn bytes(self) -> Vec { + self.0.into_inner() + } +} + +#[cfg(feature = "std")] +impl Size for Buffer { + #[inline] + fn len(&self) -> usize { + self.0.get_ref().len() + } +} + +#[cfg(feature = "std")] +impl Position for Buffer { + #[inline] + fn position(&self) -> u64 { + self.0.position() + } +} + +#[cfg(feature = "std")] +impl Read for Buffer { + type Error = io::Error; + + #[inline] + fn read_exact(&mut self, bytes: &mut [u8]) -> Result<(), Self::Error> { + self.0.read_exact(bytes) + } +} + +#[cfg(feature = "std")] +impl Write for Buffer { + type Error = io::Error; + + #[inline] + fn write_all(&mut self, bytes: &[u8]) -> Result<(), Self::Error> { + self.0.write_all(bytes) } } diff --git a/util/src/lib.rs b/util/src/lib.rs index 6662f7c0..c840bc69 100644 --- a/util/src/lib.rs +++ b/util/src/lib.rs @@ -10,6 +10,7 @@ extern crate alloc; use alloc::vec::Vec; use core::arch::asm; +use core::convert::Infallible; use core::hint::unreachable_unchecked; use core::mem::size_of; use core::ptr::{swap, swap_nonoverlapping}; @@ -18,6 +19,15 @@ use crate::transpose_util::transpose_in_place_square; mod transpose_util; +/// Converts `result` into the [`Ok`] variant of [`Result`]. +#[inline] +pub fn into_ok(result: Result) -> T { + match result { + Ok(value) => value, + _ => unreachable!("The `Infallible` value cannot be constructed."), + } +} + pub fn bits_u64(n: u64) -> usize { (64 - n.leading_zeros()) as usize }