From 3d6b65a54b0b3cfe888f5ece3c2206eef2751c93 Mon Sep 17 00:00:00 2001 From: M Alghazwi Date: Mon, 7 Jul 2025 21:56:39 +0200 Subject: [PATCH] add monolith hash --- plonky2-monolith/.gitignore | 13 + plonky2-monolith/Cargo.toml | 49 ++ plonky2-monolith/README.md | 97 ++++ plonky2-monolith/benches/allocator/mod.rs | 7 + plonky2-monolith/benches/base_proof.rs | 110 ++++ plonky2-monolith/benches/circuits/mod.rs | 96 ++++ plonky2-monolith/benches/hashing.rs | 66 +++ plonky2-monolith/benches/merkle.rs | 39 ++ plonky2-monolith/benches/recursion.rs | 302 ++++++++++ plonky2-monolith/src/gates/base_sum_custom.rs | 434 +++++++++++++++ plonky2-monolith/src/gates/gadget.rs | 391 +++++++++++++ plonky2-monolith/src/gates/mod.rs | 37 ++ plonky2-monolith/src/gates/monolith.rs | 488 +++++++++++++++++ plonky2-monolith/src/lib.rs | 9 + plonky2-monolith/src/monolith_hash/mod.rs | 436 +++++++++++++++ .../src/monolith_hash/monolith_goldilocks.rs | 517 ++++++++++++++++++ plonky2-monolith/tests/integration.rs | 69 +++ 17 files changed, 3160 insertions(+) create mode 100644 plonky2-monolith/.gitignore create mode 100644 plonky2-monolith/Cargo.toml create mode 100644 plonky2-monolith/README.md create mode 100644 plonky2-monolith/benches/allocator/mod.rs create mode 100644 plonky2-monolith/benches/base_proof.rs create mode 100644 plonky2-monolith/benches/circuits/mod.rs create mode 100644 plonky2-monolith/benches/hashing.rs create mode 100644 plonky2-monolith/benches/merkle.rs create mode 100644 plonky2-monolith/benches/recursion.rs create mode 100644 plonky2-monolith/src/gates/base_sum_custom.rs create mode 100644 plonky2-monolith/src/gates/gadget.rs create mode 100644 plonky2-monolith/src/gates/mod.rs create mode 100644 plonky2-monolith/src/gates/monolith.rs create mode 100644 plonky2-monolith/src/lib.rs create mode 100644 plonky2-monolith/src/monolith_hash/mod.rs create mode 100644 plonky2-monolith/src/monolith_hash/monolith_goldilocks.rs create mode 100644 
plonky2-monolith/tests/integration.rs diff --git a/plonky2-monolith/.gitignore b/plonky2-monolith/.gitignore new file mode 100644 index 0000000..a3547b9 --- /dev/null +++ b/plonky2-monolith/.gitignore @@ -0,0 +1,13 @@ +#IDE Related +.idea + +# Cargo build +/target +Cargo.lock + +# Profile-guided optimization +/tmp +pgo-data.profdata + +# MacOS nuisances +.DS_Store diff --git a/plonky2-monolith/Cargo.toml b/plonky2-monolith/Cargo.toml new file mode 100644 index 0000000..3ce4ad6 --- /dev/null +++ b/plonky2-monolith/Cargo.toml @@ -0,0 +1,49 @@ +[package] +name = "plonky2_monolith" +version = "0.1.0" +description = "Use Monolith hash to generate Plonky2 proofs and to write Plonky2 circuits" +edition = "2021" +license-file = "LICENSE-APACHE" +readme = "README.md" +keywords = ["cryptography", "PLONK", "hash", "zero_knowledge"] +categories = ["cryptography"] + +[dependencies] +anyhow = "1.0.75" +itertools = "0.14.0" +plonky2 = {version = "1.0.2", default-features = true} +rand_chacha = "0.9.0" +serde = "1.0.188" +unroll = "0.1.5" + +[features] +default = ["default-sponge-params"] +default-sponge-params = [] + +[dev-dependencies] +log = "0.4.20" +rstest = "0.24.0" +serial_test = "3.2.0" +env_logger = "0.11.6" +criterion = "0.5.1" +tynm = "0.1.8" + +[target.'cfg(not(target_env = "msvc"))'.dev-dependencies] +jemallocator = "0.5.0" + +[[bench]] +name = "hashing" +harness = false + +[[bench]] +name = "merkle" +harness = false + +[[bench]] +name = "base_proof" +harness = false + +[[bench]] +name = "recursion" +harness = false + diff --git a/plonky2-monolith/README.md b/plonky2-monolith/README.md new file mode 100644 index 0000000..97a2c84 --- /dev/null +++ b/plonky2-monolith/README.md @@ -0,0 +1,97 @@ +# Monolith Plonky2 +This crate provides an implementation of the [Monolith hash function](https://eprint.iacr.org/2023/1025.pdf) that can be employed in the [Plonky2 proving system](https://github.com/mir-protocol/plonky2). 
The Monolith hash function is a new zk-friendly hash function that is much faster than other state-of-the-art zk-friendly hash functions, exhibiting performance similar to the Keccak hash function. In particular, according to our initial benchmarks, Monolith is 2 to 3 times faster than Poseidon, the hash function currently employed in the Plonky2 proving system. + +This crate can be employed to: + +- Generate Plonky2 proofs employing the Monolith hash function +- Write Plonky2 circuits computing Monolith hashes, which is also useful to recursively verify Plonky2 proofs generated with Monolith. To this end, this crate provides a Plonky2 gate for the Monolith permutation. + +The crate also provides benchmarks that compare the Monolith implementation and the Monolith gate with the corresponding Poseidon components currently employed in Plonky2. + +## Usage +Generate a proof employing the Monolith hash function: +```rust +use plonky2::plonk::circuit_builder::CircuitBuilder; +use plonky2::iop::witness::PartialWitness; +use plonky2::field::goldilocks_field::GoldilocksField; +use plonky2::plonk::circuit_data::CircuitConfig; +use plonky2::field::types::Sample; +use plonky2::iop::witness::WitnessWrite; +use plonky2_monolith::monolith_hash::monolith_goldilocks::MonolithGoldilocksConfig; +use std::error::Error; + +const D: usize = 2; +type F = GoldilocksField; +fn main() -> Result<(), Box<dyn Error>> { + let config = CircuitConfig::standard_recursion_config(); + let mut builder = CircuitBuilder::<F, D>::new(config); + let init_t = builder.add_virtual_public_input(); + let mut res_t = builder.add_virtual_target(); + builder.connect(init_t, res_t); + for _ in 0..100 { + res_t = builder.mul(res_t, init_t); + } + builder.register_public_input(res_t); + let data = builder.build::<MonolithGoldilocksConfig>(); + + let mut pw = PartialWitness::<F>::new(); + let input = F::rand(); + pw.set_target(init_t, input); + + + let proof = data.prove(pw)?; + + Ok(data.verify(proof)?) 
+} +``` +Build a circuit employing the Monolith gate: +```rust +use std::cmp; +use plonky2::plonk::circuit_builder::CircuitBuilder; +use plonky2::iop::witness::PartialWitness; +use plonky2::field::goldilocks_field::GoldilocksField; +use plonky2::plonk::circuit_data::CircuitConfig; +use plonky2::field::types::Sample; +use plonky2::iop::witness::WitnessWrite; +use plonky2::gates::gate::Gate; +use plonky2::hash::hash_types::NUM_HASH_OUT_ELTS; +use plonky2_monolith::monolith_hash::monolith_goldilocks::MonolithGoldilocksConfig; +use plonky2_monolith::gates::monolith::MonolithGate; +use plonky2_monolith::monolith_hash::MonolithHash; +use std::error::Error; + +const D: usize = 2; +type F = GoldilocksField; + +fn generate_config_for_monolith() -> CircuitConfig { + let needed_wires = cmp::max(MonolithGate::<F, D>::new().num_wires(), CircuitConfig::standard_recursion_config().num_wires); + CircuitConfig { + num_wires: needed_wires, + num_routed_wires: needed_wires, + ..CircuitConfig::standard_recursion_config() + } + } + +fn main() -> Result<(), Box<dyn Error>> { + let config = generate_config_for_monolith(); + let mut builder = CircuitBuilder::<F, D>::new(config); + let inp_targets_array = builder.add_virtual_target_arr::<{NUM_HASH_OUT_ELTS}>(); + let mut res_targets_array = inp_targets_array.clone(); + for _ in 0..100 { + res_targets_array = builder.hash_or_noop::<MonolithHash>(res_targets_array.to_vec()).elements; + } + builder.register_public_inputs(&res_targets_array); + let data = builder.build::<MonolithGoldilocksConfig>(); + + + let mut pw = PartialWitness::<F>::new(); + inp_targets_array.into_iter().for_each(|t| { + let input = F::rand(); + pw.set_target(t, input); + }); + + let proof = data.prove(pw)?; + + Ok(data.verify(proof)?) 
+} +``` diff --git a/plonky2-monolith/benches/allocator/mod.rs b/plonky2-monolith/benches/allocator/mod.rs new file mode 100644 index 0000000..441e5dc --- /dev/null +++ b/plonky2-monolith/benches/allocator/mod.rs @@ -0,0 +1,7 @@ +// Set up Jemalloc +#[cfg(not(target_env = "msvc"))] +use jemallocator::Jemalloc; + +#[cfg(not(target_env = "msvc"))] +#[global_allocator] +static GLOBAL: Jemalloc = Jemalloc; diff --git a/plonky2-monolith/benches/base_proof.rs b/plonky2-monolith/benches/base_proof.rs new file mode 100644 index 0000000..ffa97f4 --- /dev/null +++ b/plonky2-monolith/benches/base_proof.rs @@ -0,0 +1,110 @@ +use crate::circuits::BaseCircuit; +use criterion::{criterion_group, criterion_main, BatchSize, Criterion}; +use plonky2::field::extension::Extendable; +use plonky2::field::goldilocks_field::GoldilocksField; +use plonky2::hash::hash_types::RichField; +use plonky2::hash::poseidon::PoseidonHash; +use plonky2::plonk::circuit_data::CircuitConfig; +use plonky2::plonk::config::{AlgebraicHasher, GenericConfig, Hasher, PoseidonGoldilocksConfig}; +use plonky2_monolith::gates::generate_config_for_monolith_gate; +use plonky2_monolith::monolith_hash::monolith_goldilocks::MonolithGoldilocksConfig; +use plonky2_monolith::monolith_hash::{Monolith, MonolithHash}; +use tynm::type_name; + +mod circuits; + +macro_rules! 
pretty_print { + ($($arg:tt)*) => { + print!("\x1b[0;36mINFO ===========>\x1b[0m "); + println!($($arg)*); + } +} + +fn bench_base_proof< + F: RichField + Extendable + Monolith, + const D: usize, + C: GenericConfig, + H: Hasher + AlgebraicHasher, +>( + c: &mut Criterion, + config: CircuitConfig, +) { + let mut group = c.benchmark_group(&format!( + "base-proof<{}, {}>", + type_name::(), + type_name::() + )); + + for log_num_hashes in [10] { + group.bench_function( + format!("build circuit for 2^{} hashes", log_num_hashes).as_str(), + |b| { + b.iter_with_large_drop(|| { + BaseCircuit::::build_base_circuit(config.clone(), log_num_hashes); + }) + }, + ); + + let base_circuit = + BaseCircuit::::build_base_circuit(config.clone(), log_num_hashes); + + pretty_print!( + "circuit size: 2^{} gates", + base_circuit.get_circuit_data().common.degree_bits() + ); + + group.bench_function( + format!("prove circuit with 2^{} hashes", log_num_hashes).as_str(), + |b| { + b.iter_batched( + || F::rand(), + |init| base_circuit.generate_base_proof(init).unwrap(), + BatchSize::PerIteration, + ) + }, + ); + + let proof = base_circuit.generate_base_proof(F::rand()).unwrap(); + + pretty_print!("proof size: {}", proof.to_bytes().len()); + + group.bench_function( + format!("verify circuit with 2^{} hashes", log_num_hashes).as_str(), + |b| { + b.iter_batched( + || (base_circuit.get_circuit_data(), proof.clone()), + |(data, proof)| data.verify(proof).unwrap(), + BatchSize::PerIteration, + ) + }, + ); + } + + group.finish(); +} + +fn benchmark(c: &mut Criterion) { + const D: usize = 2; + type F = GoldilocksField; + bench_base_proof::( + c, + CircuitConfig::standard_recursion_config(), + ); + bench_base_proof::( + c, + CircuitConfig::standard_recursion_config(), + ); + bench_base_proof::( + c, + generate_config_for_monolith_gate::(), + ); + bench_base_proof::( + c, + generate_config_for_monolith_gate::(), + ); +} + +criterion_group!(name = benches; + config = Criterion::default().sample_size(10); + 
targets = benchmark); +criterion_main!(benches); diff --git a/plonky2-monolith/benches/circuits/mod.rs b/plonky2-monolith/benches/circuits/mod.rs new file mode 100644 index 0000000..3fd62b8 --- /dev/null +++ b/plonky2-monolith/benches/circuits/mod.rs @@ -0,0 +1,96 @@ +use anyhow::Result; +use plonky2::field::extension::Extendable; +use plonky2::hash::hash_types::RichField; +use plonky2::hash::hashing::hash_n_to_m_no_pad; +use plonky2::iop::target::Target; +use plonky2::iop::witness::{PartialWitness, WitnessWrite}; +use plonky2::plonk::circuit_builder::CircuitBuilder; +use plonky2::plonk::circuit_data::{CircuitConfig, CircuitData}; +use plonky2::plonk::config::{AlgebraicHasher, GenericConfig, Hasher}; +use plonky2::plonk::proof::ProofWithPublicInputs; +use std::marker::PhantomData; + +/// Data structure with all input/output targets and the `CircuitData` for the circuit proven +/// in base proofs. The circuit is designed to be representative of a common base circuit +/// operating on a common public state employing also some private data. 
+/// The computation performed on the state was chosen to employ commonly used gates, such as +/// arithmetic and hash ones +pub struct BaseCircuit< + F: RichField + Extendable, + C: GenericConfig, + const D: usize, + H: Hasher + AlgebraicHasher, +> { + private_input: Target, + public_input: Target, + public_output: Target, + circuit_data: CircuitData, + num_powers: usize, + _hasher: PhantomData, +} + +impl< + F: RichField + Extendable, + C: GenericConfig, + const D: usize, + H: Hasher + AlgebraicHasher, + > BaseCircuit +{ + pub fn build_base_circuit(config: CircuitConfig, log_num_hashes: usize) -> Self { + let num_hashes: usize = 1usize << log_num_hashes; + + let mut builder = CircuitBuilder::::new(config); + let mut res_t = builder.add_virtual_public_input(); + let init_t = res_t; + let zero = builder.zero(); + let to_be_hashed_t = builder.add_virtual_target(); + for _ in 0..num_hashes { + res_t = builder.mul(res_t, init_t); + res_t = builder.hash_n_to_m_no_pad::(vec![res_t, to_be_hashed_t, zero, zero], 1)[0]; + } + + let out_t = builder.add_virtual_public_input(); + let is_eq_t = builder.is_equal(out_t, res_t); + builder.assert_one(is_eq_t.target); + + let data = builder.build::(); + + Self { + private_input: to_be_hashed_t, + public_input: init_t, + public_output: out_t, + circuit_data: data, + num_powers: num_hashes, + _hasher: PhantomData::, + } + } + + pub fn generate_base_proof(&self, init: F) -> Result> { + let mut pw = PartialWitness::::new(); + + pw.set_target(self.public_input, init); + let to_be_hashed = F::rand(); + pw.set_target(self.private_input, to_be_hashed); + let mut res = init; + for _ in 0..self.num_powers { + res = res.mul(init); + res = + hash_n_to_m_no_pad::<_, H::Permutation>(&[res, to_be_hashed, F::ZERO, F::ZERO], 1) + [0]; + } + + pw.set_target(self.public_output, res); + + let proof = self.circuit_data.prove(pw)?; + + self.circuit_data.verify(proof.clone())?; + + assert_eq!(proof.public_inputs[1], res); + + Ok(proof) + } + + pub fn 
get_circuit_data(&self) -> &CircuitData { + &self.circuit_data + } +} diff --git a/plonky2-monolith/benches/hashing.rs b/plonky2-monolith/benches/hashing.rs new file mode 100644 index 0000000..b815101 --- /dev/null +++ b/plonky2-monolith/benches/hashing.rs @@ -0,0 +1,66 @@ +use criterion::{criterion_group, criterion_main, BatchSize, Criterion}; +use plonky2::field::goldilocks_field::GoldilocksField; +use plonky2::field::types::Sample; +use plonky2::hash::hash_types::{BytesHash, RichField}; +use plonky2::hash::keccak::KeccakHash; +use plonky2::hash::poseidon::{Poseidon, SPONGE_WIDTH}; +use plonky2::plonk::config::Hasher; +use plonky2_monolith::monolith_hash::Monolith; +use rand_chacha::rand_core::SeedableRng; +use rand_chacha::ChaCha12Rng; +use tynm::type_name; + +mod allocator; + +pub(crate) fn bench_keccak(c: &mut Criterion) { + let mut rng = ChaCha12Rng::seed_from_u64(38u64); + c.bench_function("keccak256", |b| { + b.iter_batched( + || { + ( + BytesHash::<32>::sample(&mut rng), + BytesHash::<32>::sample(&mut rng), + ) + }, + |(left, right)| as Hasher>::two_to_one(left, right), + BatchSize::SmallInput, + ) + }); +} + +pub(crate) fn bench_poseidon(c: &mut Criterion) { + c.bench_function( + &format!("poseidon<{}, {}>", type_name::(), SPONGE_WIDTH,), + |b| { + b.iter_batched( + || F::rand_array::(), + |state| F::poseidon(state), + BatchSize::SmallInput, + ) + }, + ); +} + +pub(crate) fn bench_monolith(c: &mut Criterion) { + c.bench_function( + &format!("monolith<{}, {}>", type_name::(), SPONGE_WIDTH,), + |b| { + b.iter_batched( + || F::rand_array::(), + |state| F::monolith(state), + BatchSize::SmallInput, + ) + }, + ); +} + +fn criterion_benchmark(c: &mut Criterion) { + bench_poseidon::(c); + bench_monolith::(c); + bench_keccak::(c); +} + +criterion_group!(name = benches; + config = Criterion::default().sample_size(500); + targets = criterion_benchmark); +criterion_main!(benches); diff --git a/plonky2-monolith/benches/merkle.rs b/plonky2-monolith/benches/merkle.rs 
new file mode 100644
index 0000000..8524879
--- /dev/null
+++ b/plonky2-monolith/benches/merkle.rs
@@ -0,0 +1,59 @@
+use criterion::{criterion_group, criterion_main, BatchSize, BenchmarkId, Criterion};
+use plonky2::field::goldilocks_field::GoldilocksField;
+use plonky2::hash::hash_types::RichField;
+use plonky2::hash::keccak::KeccakHash;
+use plonky2::hash::merkle_tree::MerkleTree;
+use plonky2::hash::poseidon::PoseidonHash;
+use plonky2::plonk::config::Hasher;
+use plonky2_monolith::monolith_hash::MonolithHash;
+use tynm::type_name;
+
+mod allocator;
+
+/// Number of field elements in each Merkle tree leaf.
+const ELEMS_PER_LEAF: usize = 135;
+
+/// Benchmarks building a Merkle tree over `2^13`, `2^14` and `2^15` leaves with hasher `H`
+/// over field `F`.
+///
+/// The leaves are generated once per tree size. Each iteration receives its own clone through
+/// `iter_batched`, so the copy is made in the untimed setup step and only the tree
+/// construction is measured.
+pub(crate) fn bench_merkle_tree<F: RichField, H: Hasher<F>>(c: &mut Criterion) {
+    let mut group = c.benchmark_group(&format!(
+        "merkle-tree<{}, {}>",
+        type_name::<F>(),
+        type_name::<H>()
+    ));
+    group.sample_size(10);
+
+    for size_log in [13, 14, 15] {
+        let size = 1 << size_log;
+        group.bench_with_input(BenchmarkId::from_parameter(size), &size, |b, _| {
+            let leaves = vec![F::rand_vec(ELEMS_PER_LEAF); size];
+            b.iter_batched(
+                // Setup (untimed): hand every iteration a fresh copy of the leaves.
+                || leaves.clone(),
+                // Routine (timed): build the tree from the owned leaves.
+                |leaves| MerkleTree::<F, H>::new(leaves, 0),
+                BatchSize::LargeInput,
+            );
+        });
+    }
+
+    group.finish();
+}
+
+/// Runs the Merkle tree benchmark once per hasher under comparison.
+fn criterion_benchmark(c: &mut Criterion) {
+    // NOTE(review): the generic arguments in this file were reconstructed from a copy of the
+    // patch in which they were stripped; confirm the hasher order and the `KeccakHash` width
+    // against the original commit.
+    bench_merkle_tree::<GoldilocksField, PoseidonHash>(c);
+    bench_merkle_tree::<GoldilocksField, KeccakHash<25>>(c);
+    bench_merkle_tree::<GoldilocksField, MonolithHash>(c);
+}
+
+criterion_group!(benches, criterion_benchmark);
+criterion_main!(benches);
diff --git a/plonky2-monolith/benches/recursion.rs b/plonky2-monolith/benches/recursion.rs
new file mode 100644
index 0000000..c6367d4
--- /dev/null
+++ b/plonky2-monolith/benches/recursion.rs
@@ -0,0 +1,302 @@
+use crate::circuits::BaseCircuit;
+use anyhow::Result;
+use criterion::{criterion_group, criterion_main, BatchSize, Criterion};
+use plonky2::field::extension::Extendable;
+use plonky2::field::goldilocks_field::GoldilocksField;
+use plonky2::hash::hash_types::RichField;
+use plonky2::hash::poseidon::PoseidonHash;
+use plonky2::iop::witness::{PartialWitness, WitnessWrite};
+use plonky2::plonk::circuit_builder::CircuitBuilder;
+use plonky2::plonk::circuit_data::{
+    CircuitConfig, CircuitData, CommonCircuitData, VerifierCircuitTarget,
+};
+use plonky2::plonk::config::{AlgebraicHasher, GenericConfig, PoseidonGoldilocksConfig}; +use plonky2::plonk::proof::{ProofWithPublicInputs, ProofWithPublicInputsTarget}; +use plonky2_monolith::gates::generate_config_for_monolith_gate; +use plonky2_monolith::monolith_hash::monolith_goldilocks::MonolithGoldilocksConfig; +use plonky2_monolith::monolith_hash::Monolith; +use std::marker::PhantomData; +use tynm::type_name; + +mod circuits; + +macro_rules! pretty_print { + ($($arg:tt)*) => { + print!("\x1b[0;36mINFO ===========>\x1b[0m "); + println!($($arg)*); + } +} + +/// Data structure with all input/output targets and the `CircuitData` for each circuit employed +/// to recursively shrink a proof up to the recursion threshold. The data structure contains a set +/// of targets and a `CircuitData` for each shrink step +struct ShrinkCircuit< + F: RichField + Extendable, + C: GenericConfig, + InnerC: GenericConfig, + const D: usize, +> { + proof_targets: Vec>, + circuit_data: Vec>, + inner_data: Vec, + _inner_c: PhantomData, +} + +impl< + F: RichField + Extendable, + C: GenericConfig, + InnerC: GenericConfig, + const D: usize, + > ShrinkCircuit +where + InnerC::Hasher: AlgebraicHasher, + C::Hasher: AlgebraicHasher, +{ + pub fn build_shrink_circuit( + inner_cd: &CommonCircuitData, + inner_config: CircuitConfig, + rec_config: CircuitConfig, + ) -> Self { + let mut circuit_data = inner_cd; + let mut shrink_circuit = Self { + proof_targets: Vec::new(), + circuit_data: Vec::new(), + inner_data: Vec::new(), + _inner_c: PhantomData::, + }; + loop { + let mut builder = if shrink_circuit.num_shrink_steps() > 0 { + CircuitBuilder::::new(rec_config.clone()) + } else { + CircuitBuilder::::new(inner_config.clone()) + }; + let pt = builder.add_virtual_proof_with_pis(circuit_data); + + let inner_data = + builder.add_virtual_verifier_data(circuit_data.config.fri_config.cap_height); + if shrink_circuit.num_shrink_steps() > 0 { + builder.verify_proof::(&pt, &inner_data, circuit_data); + } 
else { + builder.verify_proof::(&pt, &inner_data, circuit_data); + } + + for &pi_t in pt.public_inputs.iter() { + let t = builder.add_virtual_public_input(); + builder.connect(pi_t, t); + } + + let data = builder.build::(); + + shrink_circuit.proof_targets.push(pt); + shrink_circuit.circuit_data.push(data); + shrink_circuit.inner_data.push(inner_data); + circuit_data = &shrink_circuit.circuit_data.last().unwrap().common; + // we run the recursion until we get to a fixed circuit size, that is the + // `RECURSION_THRESHOLD`; this is necessary if endless recursion has to be applied, as + // the size of the circuit to be recursively verified should be fixed (as in the + // recursive circuits found in PR 883 https://github.com/mir-protocol/plonky2/pull/883 + // of Plonky2) + if circuit_data.degree_bits() == RECURSION_THRESHOLD { + break; + } + } + + shrink_circuit + } + + fn set_witness>( + pw: &mut PartialWitness, + proof: &ProofWithPublicInputs, + pt: &ProofWithPublicInputsTarget, + inner_data: &VerifierCircuitTarget, + circuit_data: &CircuitData, + ) where + GC::Hasher: AlgebraicHasher, + { + pw.set_proof_with_pis_target(pt, proof); + pw.set_cap_target( + &inner_data.constants_sigmas_cap, + &circuit_data.verifier_only.constants_sigmas_cap, + ); + pw.set_hash_target( + inner_data.circuit_digest, + circuit_data.verifier_only.circuit_digest, + ); + } + + pub fn shrink_proof<'a>( + &'a self, + inner_proof: ProofWithPublicInputs, + inner_cd: &'a CircuitData, + ) -> Result> { + let mut proof = None; + let mut circuit_data = None; + + for ((pt, cd), inner_data) in self + .proof_targets + .iter() + .zip(self.circuit_data.iter()) + .zip(self.inner_data.iter()) + { + let mut pw = PartialWitness::new(); + match (proof, circuit_data) { + (None, None) => Self::set_witness(&mut pw, &inner_proof, pt, inner_data, inner_cd), + (Some(inner_proof), Some(inner_cd)) => { + Self::set_witness(&mut pw, &inner_proof, pt, inner_data, inner_cd); + } + _ => unreachable!(), + } + proof = 
Some(cd.prove(pw)?); + circuit_data = Some(cd); + } + + Ok(proof.unwrap()) + } + + pub fn num_shrink_steps(&self) -> usize { + self.circuit_data.len() + } + + pub fn get_circuit_data(&self) -> &CircuitData { + self.circuit_data.last().unwrap() + } +} + +struct HashConfig> { + gen_config: PhantomData, + circuit_config: CircuitConfig, +} + +fn bench_recursive_proof< + F: RichField + Extendable + Monolith, + const D: usize, + const RECURSION_THRESHOLD: usize, + C: GenericConfig, + InnerC: GenericConfig, +>( + c: &mut Criterion, + rec_conf: &HashConfig, + inner_conf: &HashConfig, +) where + InnerC::Hasher: AlgebraicHasher, + C::Hasher: AlgebraicHasher, +{ + let mut group = c.benchmark_group(&format!( + "recursive-proof<{}, {}>", + type_name::(), + type_name::() + )); + + for log_num_hashes in [11, 13, 15] { + let base_circuit = BaseCircuit::::build_base_circuit( + CircuitConfig::standard_recursion_config(), + log_num_hashes, + ); + + let base_circuit_degree = base_circuit.get_circuit_data().common.degree_bits(); + + let proof = base_circuit.generate_base_proof(F::rand()).unwrap(); + + let inner_cd = &base_circuit.get_circuit_data().common; + + group.bench_function( + format!("build circuit for degree {}", base_circuit_degree).as_str(), + |b| { + b.iter_with_large_drop(|| { + ShrinkCircuit::::build_shrink_circuit::( + inner_cd, + inner_conf.circuit_config.clone(), + rec_conf.circuit_config.clone(), + ); + }) + }, + ); + + let shrink_circuit = + ShrinkCircuit::::build_shrink_circuit::( + inner_cd, + inner_conf.circuit_config.clone(), + rec_conf.circuit_config.clone(), + ); + + pretty_print!("shrink steps: {}", shrink_circuit.num_shrink_steps()); + + let inner_cd = base_circuit.get_circuit_data(); + + group.bench_function( + format!("shrinking proof of degree {}", base_circuit_degree).as_str(), + |b| { + b.iter_batched( + || proof.clone(), + |proof| shrink_circuit.shrink_proof(proof, inner_cd).unwrap(), + BatchSize::PerIteration, + ) + }, + ); + + let shrunk_proof = 
shrink_circuit.shrink_proof(proof, inner_cd).unwrap(); + let shrunk_cd = shrink_circuit.get_circuit_data(); + + assert_eq!(shrunk_cd.common.degree_bits(), RECURSION_THRESHOLD); + + group.bench_function( + format!("verify proof for degree {}", base_circuit_degree).as_str(), + |b| { + b.iter_batched( + || shrunk_proof.clone(), + |proof| shrunk_cd.verify(proof).unwrap(), + BatchSize::PerIteration, + ) + }, + ); + } + + group.finish(); +} + +const POSEIDON_RECURSION_THRESHOLD: usize = 12; +const MONOLITH_RECURSION_THRESHOLD: usize = 15; + +fn benchmark(c: &mut Criterion) { + const D: usize = 2; + type F = GoldilocksField; + let poseidon_config = HashConfig:: { + gen_config: PhantomData::default(), + circuit_config: CircuitConfig::standard_recursion_config(), + }; + let monolith_config = HashConfig:: { + gen_config: PhantomData::default(), + circuit_config: generate_config_for_monolith_gate::(), + }; + bench_recursive_proof::( + c, + &poseidon_config, + &poseidon_config, + ); + bench_recursive_proof::< + F, + D, + POSEIDON_RECURSION_THRESHOLD, + PoseidonGoldilocksConfig, + MonolithGoldilocksConfig, + >(c, &poseidon_config, &monolith_config); + bench_recursive_proof::< + F, + D, + MONOLITH_RECURSION_THRESHOLD, + MonolithGoldilocksConfig, + PoseidonGoldilocksConfig, + >(c, &monolith_config, &poseidon_config); + bench_recursive_proof::< + F, + D, + MONOLITH_RECURSION_THRESHOLD, + MonolithGoldilocksConfig, + MonolithGoldilocksConfig, + >(c, &monolith_config, &monolith_config); +} + +criterion_group!(name = benches; + config = Criterion::default().sample_size(10); + targets = benchmark); +criterion_main!(benches); diff --git a/plonky2-monolith/src/gates/base_sum_custom.rs b/plonky2-monolith/src/gates/base_sum_custom.rs new file mode 100644 index 0000000..1788b80 --- /dev/null +++ b/plonky2-monolith/src/gates/base_sum_custom.rs @@ -0,0 +1,434 @@ +use plonky2::field::extension::Extendable; +use plonky2::field::packed::PackedField; +use plonky2::field::types::{Field, Field64}; 
+use plonky2::gates::gate::Gate; +use plonky2::gates::packed_util::PackedEvaluableBase; +use plonky2::gates::util::StridedConstraintConsumer; +use plonky2::hash::hash_types::RichField; +use plonky2::iop::ext_target::ExtensionTarget; +use plonky2::iop::generator::{GeneratedValues, SimpleGenerator, WitnessGeneratorRef}; +use plonky2::iop::target::Target; +use plonky2::iop::witness::{PartitionWitness, Witness, WitnessWrite}; +use plonky2::plonk::circuit_builder::CircuitBuilder; +use plonky2::plonk::circuit_data::{CircuitConfig, CommonCircuitData}; +use plonky2::plonk::plonk_common::{reduce_with_powers, reduce_with_powers_ext_circuit}; +use plonky2::plonk::vars::{ + EvaluationTargets, EvaluationVars, EvaluationVarsBaseBatch, EvaluationVarsBasePacked, +}; +use plonky2::util::log_floor; +use plonky2::util::serialization::{Buffer, IoResult, Read, Write}; +use std::ops::Range; + +/// A gate which can decompose an element of `GoldilocksField` into base B little-endian limbs. +/// This gate is customized to be used for lookups of the Monolith hash function, and thus it has +/// the following differences w.r.t. the Plonky2 `BaseSum` gate: +/// - It allows to pack many decomposition operations on a single row +/// - It does not range-check each limb, since the lookup table to be applied on each limb will +/// already implicitly perform a range-check +/// - It supports the decomposition of any field element, while the `BaseSum` gate unpacks only +/// elements of at most `floor(log_B(F::order))` +#[derive(Copy, Clone, Debug)] +pub struct BaseSumCustomGate { + num_limbs: usize, + num_ops: usize, +} + +fn log_ceil(n: u64, base: u64) -> usize { + let res = log_floor(n, base); + if base.pow(res as u32) < n { + res + 1 + } else { + res + } +} + +impl BaseSumCustomGate { + /// Instantiate a new `BaseSumCustomGate` to decompose a Goldilocks field element in + /// `num_limbs` base B little-endian limbs. 
`config` allows to compute the number of operations + /// that can be performed with a single gate + pub fn new(num_limbs: usize, config: &CircuitConfig) -> Self { + let wires_per_op = Self::wires_per_op_from_limbs(num_limbs); + let num_ops = config.num_routed_wires / wires_per_op; + assert!( + num_ops > 0, + "cannot decompose in {} limbs with {} routed wires", + num_limbs, + config.num_routed_wires + ); + Self { num_limbs, num_ops } + } + + /// Instantiate a new `BaseSumCustomGate` employing the exact number of base B limbs necessary + /// to represent an arbitrary field element in `F` + pub fn new_from_config(config: &CircuitConfig) -> Self { + let num_limbs = + log_ceil(F::ORDER, B as u64).min(config.num_routed_wires - Self::START_LIMBS - 2); + Self::new(num_limbs, config) + } + + const WIRE_SUM: usize = 0; + const START_LIMBS: usize = 1; + + fn wires_per_op_from_limbs(num_limbs: usize) -> usize { + // num limbs + 1 wire for the element to be decomposed + 2 wires to range-check the + // field element obtained by re-composing the limbs + num_limbs + 1 + 2 + } + + fn wires_per_op(&self) -> usize { + Self::wires_per_op_from_limbs(self.num_limbs) + } + + /// Index of the wire storing the field element to be decomposed in the `i`-th operation of the + /// gate + pub fn ith_wire_sum(&self, i: usize) -> usize { + let wires_per_op = self.wires_per_op(); + i * wires_per_op + Self::WIRE_SUM + } + + /// Returns the index of the limb wires for the i-th operation of the gate. 
+ pub fn ith_limbs(&self, i: usize) -> Range { + let wires_per_op = self.wires_per_op(); + (i * wires_per_op + Self::START_LIMBS) + ..(i * wires_per_op + Self::START_LIMBS + self.num_limbs) + } +} + +impl, const D: usize, const B: usize> Gate + for BaseSumCustomGate +{ + fn id(&self) -> String { + format!("{self:?} + Base: {B}") + } + + fn serialize(&self, dst: &mut Vec, _common_data: &CommonCircuitData) -> IoResult<()> { + dst.write_usize(self.num_limbs)?; + dst.write_usize(self.num_ops) + } + + fn deserialize(src: &mut Buffer, _common_data: &CommonCircuitData) -> IoResult { + let num_limbs = src.read_usize()?; + let num_ops = src.read_usize()?; + Ok(Self { num_limbs, num_ops }) + } + + fn num_ops(&self) -> usize { + self.num_ops + } + + fn eval_unfiltered(&self, vars: EvaluationVars) -> Vec { + let mut constraints = Vec::with_capacity(self.num_ops * 3); + for i in 0..self.num_ops { + // Splitting constraint + let sum = vars.local_wires[self.ith_wire_sum(i)]; + let limbs = vars.local_wires[self.ith_limbs(i)].to_vec(); + let computed_sum = reduce_with_powers(&limbs, F::Extension::from_canonical_usize(B)); + constraints.push(computed_sum - sum); + + // Boundary constraints + let z = vars.local_wires[self.ith_wire_sum(i) + self.num_limbs + 1]; + let z_prime = vars.local_wires[self.ith_wire_sum(i) + self.num_limbs + 2]; + + assert_eq!(limbs.len() % 2, 0); + + let base = F::Extension::from_canonical_usize(B); + let half_len = limbs.len() / 2; + let a = limbs + .iter() + .take(half_len - 1) + .rev() + .fold(limbs[half_len - 1], |acc, el| acc * base + *el); + let temp = limbs + .iter() + .rev() + .skip(1) + .take(half_len - 1) + .fold(limbs[limbs.len() - 1], |acc, el| acc * base + *el); + let b = temp - z; + + let two_32_m1 = F::Extension::from_canonical_usize(((1_u64 << 32) - 1) as usize); + + constraints.push(a * b); + constraints.push((z - two_32_m1) * z_prime - z); + } + + constraints + } + + fn eval_unfiltered_base_batch(&self, vars_base: EvaluationVarsBaseBatch) -> 
Vec { + self.eval_unfiltered_base_batch_packed(vars_base) + } + + fn eval_unfiltered_circuit( + &self, + builder: &mut CircuitBuilder, + vars: EvaluationTargets, + ) -> Vec> { + let mut constraints = Vec::with_capacity(self.num_ops * 3); + for i in 0..self.num_ops { + // Splitting constraint + let base = builder.constant(F::from_canonical_usize(B)); + let sum = vars.local_wires[self.ith_wire_sum(i)]; + let limbs = vars.local_wires[self.ith_limbs(i)].to_vec(); + let computed_sum = reduce_with_powers_ext_circuit(builder, &limbs, base); + constraints.push(builder.sub_extension(computed_sum, sum)); + + // Boundary constraints + let z = vars.local_wires[self.ith_wire_sum(i) + self.num_limbs + 1]; + let z_prime = vars.local_wires[self.ith_wire_sum(i) + self.num_limbs + 2]; + + assert_eq!(limbs.len() % 2, 0); + + let base = F::from_canonical_usize(B); + let half_len = limbs.len() / 2; + let a = limbs + .iter() + .take(half_len - 1) + .rev() + .fold(limbs[half_len - 1], |acc, el| { + builder.mul_const_add_extension(base, acc, *el) + }); + let temp = limbs + .iter() + .rev() + .skip(1) + .take(half_len - 1) + .fold(limbs[limbs.len() - 1], |acc, el| { + builder.mul_const_add_extension(base, acc, *el) + }); + let b = builder.sub_extension(temp, z); + + let two_32_m1 = builder.constant_extension(F::Extension::from_canonical_usize( + ((1_u64 << 32) - 1) as usize, + )); + + let temp = builder.mul_extension(a, b); + constraints.push(temp); + + let mut temp = builder.sub_extension(z, two_32_m1); + temp = builder.mul_extension(temp, z_prime); + temp = builder.sub_extension(temp, z); + constraints.push(temp); + } + constraints + } + + fn generators(&self, row: usize, _local_constants: &[F]) -> Vec> { + (0..self.num_ops) + .map(|i| { + let gen = BaseSplitGenerator:: { + row, + num_limbs: self.num_limbs, + op: i, + }; + WitnessGeneratorRef::new(gen.adapter()) + }) + .collect() + } + + // 1 for the sum then `num_limbs` for the limbs. 
+ // + 2 for the boundary constraints + fn num_wires(&self) -> usize { + (1 + self.num_limbs + 2) * self.num_ops + } + + fn num_constants(&self) -> usize { + 0 + } + + // 2 from boundary constraint of degree 2 + fn degree(&self) -> usize { + 2 + } + + // num_ops for the splitting, + 2 * num_ops for the boundary constraints + fn num_constraints(&self) -> usize { + 3 * self.num_ops + } +} + +impl, const D: usize, const B: usize> PackedEvaluableBase + for BaseSumCustomGate +{ + fn eval_unfiltered_base_packed>( + &self, + vars: EvaluationVarsBasePacked

, + mut yield_constr: StridedConstraintConsumer

, + ) { + for i in 0..self.num_ops { + // Splitting constraint + let sum = vars.local_wires[self.ith_wire_sum(i)]; + let limbs = vars.local_wires.view(self.ith_limbs(i)); + let computed_sum = reduce_with_powers(limbs, F::from_canonical_usize(B)); + + yield_constr.one(computed_sum - sum); + + // Boundary constraints + let z = vars.local_wires[self.ith_wire_sum(i) + self.num_limbs + 1]; + let z_prime = vars.local_wires[self.ith_wire_sum(i) + self.num_limbs + 2]; + + assert_eq!(limbs.len() % 2, 0); + + let base = F::from_canonical_usize(B); + let half_len = limbs.len() / 2; + let a = (0..half_len - 1) + .rev() + .fold(limbs[half_len - 1], |acc, i| acc * base + limbs[i]); + let temp = (half_len..limbs.len() - 1) + .rev() + .fold(limbs[limbs.len() - 1], |acc, i| acc * base + limbs[i]); + let b = temp - z; + + let two_32_m1 = F::from_canonical_usize(((1_u64 << 32) - 1) as usize); + + yield_constr.one(a * b); + yield_constr.one((z - two_32_m1) * z_prime - z); + } + } +} +/// Generator for each operation performed in a `BaseSumCustomGate`: it computes the limb +/// decomposition of the field element to be decomposed in the given operation +#[derive(Debug, Default)] +pub struct BaseSplitGenerator { + row: usize, + num_limbs: usize, + op: usize, +} + +impl BaseSplitGenerator { + pub(crate) fn new(row: usize, num_limbs: usize, op: usize) -> Self { + Self { row, num_limbs, op } + } + + fn wires_per_op(&self) -> usize { + BaseSumCustomGate::::wires_per_op_from_limbs(self.num_limbs) + } + + pub(crate) fn wire_sum(&self) -> Target { + Target::wire( + self.row, + self.wires_per_op() * self.op + BaseSumCustomGate::::WIRE_SUM, + ) + } + + pub(crate) fn limbs_wires(&self) -> Vec { + ((self.wires_per_op() * self.op + BaseSumCustomGate::::START_LIMBS) + ..(self.wires_per_op() * self.op + + BaseSumCustomGate::::START_LIMBS + + self.num_limbs)) + .map(|i| Target::wire(self.row, i)) + .collect() + } + + pub(crate) fn boundary_constraints_wires(&self) -> Vec { + ((self.wires_per_op() * 
self.op + BaseSumCustomGate::::START_LIMBS + self.num_limbs) + ..(self.wires_per_op() * self.op + + BaseSumCustomGate::::START_LIMBS + + self.num_limbs + + 2)) + .map(|i| Target::wire(self.row, i)) + .collect() + } +} + +impl, const B: usize, const D: usize> SimpleGenerator + for BaseSplitGenerator +{ + fn id(&self) -> String { + "BaseSplitRestrictGenerator".to_string() + } + + fn dependencies(&self) -> Vec { + vec![self.wire_sum()] + } + + fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { + let sum_value = witness.get_target(self.wire_sum()).to_canonical_u64() as usize; + debug_assert_eq!( + (0..self.num_limbs).fold(sum_value, |acc, _| acc / B), + 0, + "Integer too large to fit in given number of limbs" + ); + + let limbs = self.limbs_wires(); + let limbs_value = (0..self.num_limbs) + .zip(limbs.iter()) + .scan(sum_value, |acc, (_, t)| { + let tmp = F::from_canonical_usize(*acc % B); + *acc /= B; + out_buffer.set_target(*t, tmp); + Some(tmp) + }) + .collect::>(); + + assert_eq!(limbs_value.len() % 2, 0); + + let base = F::from_canonical_usize(B); + let half_len = limbs_value.len() / 2; + let a = limbs_value + .iter() + .take(half_len - 1) + .rev() + .fold(limbs_value[half_len - 1], |acc, el| acc * base + *el); + let b = limbs_value + .iter() + .rev() + .skip(1) + .take(half_len - 1) + .fold(limbs_value[limbs.len() - 1], |acc, el| acc * base + *el); + + let z_field = if a == F::ZERO { F::ONE } else { b }; + let z_prime_field = + F::inverse(&(z_field - F::from_canonical_u64(1_u64 << 32) + F::ONE)) * z_field; + out_buffer.set_target(self.boundary_constraints_wires()[0], z_field); + out_buffer.set_target(self.boundary_constraints_wires()[1], z_prime_field); + + assert_eq!( + z_prime_field * (z_field - F::from_canonical_u64(1_u64 << 32) + F::ONE), + z_field + ); + } + + fn serialize(&self, dst: &mut Vec, _common_data: &CommonCircuitData) -> IoResult<()> { + dst.write_usize(self.row)?; + dst.write_usize(self.num_limbs)?; + 
dst.write_usize(self.op) + } + + fn deserialize(src: &mut Buffer, _common_data: &CommonCircuitData) -> IoResult { + let row = src.read_usize()?; + let num_limbs = src.read_usize()?; + let op = src.read_usize()?; + Ok(Self { row, num_limbs, op }) + } +} + +#[cfg(test)] +mod tests { + use crate::gates::base_sum_custom::BaseSumCustomGate; + use crate::monolith_hash::{LOOKUP_NUM_LIMBS, LOOKUP_SIZE}; + use anyhow::Result; + use plonky2::field::goldilocks_field::GoldilocksField; + use plonky2::gates::gate_testing::{test_eval_fns, test_low_degree}; + use plonky2::plonk::circuit_data::CircuitConfig; + use plonky2::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; + + #[test] + fn low_degree() { + test_low_degree::(BaseSumCustomGate::<{ LOOKUP_SIZE }>::new( + LOOKUP_NUM_LIMBS, + &CircuitConfig::standard_recursion_config(), + )) + } + + #[test] + fn eval_fns() -> Result<()> { + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + test_eval_fns::(BaseSumCustomGate::<{ LOOKUP_SIZE }>::new( + LOOKUP_NUM_LIMBS, + &CircuitConfig::standard_recursion_config(), + )) + } +} diff --git a/plonky2-monolith/src/gates/gadget.rs b/plonky2-monolith/src/gates/gadget.rs new file mode 100644 index 0000000..c66c7ba --- /dev/null +++ b/plonky2-monolith/src/gates/gadget.rs @@ -0,0 +1,391 @@ +use crate::gates::base_sum_custom::{BaseSplitGenerator, BaseSumCustomGate}; +use crate::gates::monolith::MonolithGate; +use crate::monolith_hash::{ + Monolith, MonolithHash, MonolithPermutation, LOOKUP_BITS, LOOKUP_NUM_LIMBS, LOOKUP_SIZE, + NUM_BARS, N_ROUNDS, SPONGE_WIDTH, +}; +use plonky2::field::extension::Extendable; +use plonky2::gates::lookup_table::LookupTable; +use plonky2::hash::hash_types::RichField; +use plonky2::hash::hashing::PlonkyPermutation; +use plonky2::iop::generator::{GeneratedValues, SimpleGenerator}; +use plonky2::iop::target::{BoolTarget, Target}; +use plonky2::iop::witness::{PartitionWitness, Witness, WitnessWrite}; +use 
plonky2::plonk::circuit_builder::CircuitBuilder; +use plonky2::plonk::circuit_data::CommonCircuitData; +use plonky2::plonk::config::AlgebraicHasher; +use plonky2::util::serialization::{Buffer, IoResult}; +use std::sync::Arc; + +/// `SplitAndLookup` provides a method to perform the following operation in a Plonky2 circuit: +/// 1) Split the input element into a list of targets, where each one represents a +/// base B limb of the element, with little-endian ordering +/// 2) Applies a lookup table (which should be defined only over base B input values) to each element +/// 3) Composes the final target using the outputs of the lookup table +pub trait SplitAndLookup { + /// Split and lookup functionality: `x` is the input element, `num_limbs` the number of base `B` + /// limbs `x` has to be split into, and `lut_index` identifies the lookup table to be applied, + /// which is assumed to have been already added to the set of lookup tables of the circuit + fn split_le_lookup(&mut self, x: Target, num_limbs: usize, lut_index: usize) -> Target; +} + +impl, const D: usize, const B: usize> SplitAndLookup + for CircuitBuilder +{ + fn split_le_lookup(&mut self, x: Target, num_limbs: usize, lut_index: usize) -> Target { + // Split into individual targets (decompose) + let gate_type = BaseSumCustomGate::::new(num_limbs, &self.config); + let (gate, i) = self.find_slot(gate_type, &[F::from_canonical_usize(num_limbs)], &[]); + let sum = Target::wire(gate, gate_type.ith_wire_sum(i)); + self.connect(x, sum); + + let split_targets_in = Target::wires_from_range(gate, gate_type.ith_limbs(i)); + + // Apply lookups + let mut split_targets_out = vec![]; + for i in 0..num_limbs { + split_targets_out.push(self.add_lookup_from_index(split_targets_in[i], lut_index)); + } + + // Get final output target (compose) + let limbs = split_targets_out; + + let (row, i) = self.find_slot(gate_type, &[F::from_canonical_usize(num_limbs)], &[]); + for (limb, wire) in limbs.iter().zip(gate_type.ith_limbs(i)) { 
+ self.connect(*limb, Target::wire(row, wire)); + } + + self.add_simple_generator(BaseSumCustomRestrictGenerator::( + BaseSplitGenerator::new(row, num_limbs, i), + )); + + Target::wire(row, gate_type.ith_wire_sum(i)) + } +} + +#[derive(Debug, Default)] +struct BaseSumCustomRestrictGenerator(BaseSplitGenerator); + +impl, const B: usize, const D: usize> SimpleGenerator + for BaseSumCustomRestrictGenerator +{ + fn id(&self) -> String { + "BaseSumCustomRestrictGenerator".to_string() + } + + fn dependencies(&self) -> Vec { + self.0.limbs_wires() + } + + fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { + let sum = self + .0 + .limbs_wires() + .iter() + .map(|&t| witness.get_target(t)) + .rev() + .fold(F::ZERO, |acc, limb| acc * F::from_canonical_usize(B) + limb); + + out_buffer.set_target(self.0.wire_sum(), sum); + } + + fn serialize(&self, dst: &mut Vec, _common_data: &CommonCircuitData) -> IoResult<()> { + self.0.serialize(dst, _common_data) + } + + fn deserialize(src: &mut Buffer, _common_data: &CommonCircuitData) -> IoResult { + let gen = BaseSplitGenerator::deserialize(src, _common_data)?; + Ok(BaseSumCustomRestrictGenerator::(gen)) + } +} + +impl AlgebraicHasher for MonolithHash { + type AlgebraicPermutation = MonolithPermutation; + + fn permute_swapped( + inputs: Self::AlgebraicPermutation, + swap: BoolTarget, + builder: &mut CircuitBuilder, + ) -> Self::AlgebraicPermutation + where + F: RichField + Extendable, + { + let lut_index = add_monolith_lookup_table(builder); + let gate_type = MonolithGate::::new(); + let gate = builder.add_gate(gate_type, vec![]); + + let swap_wire = MonolithGate::::WIRE_SWAP; + let swap_wire = Target::wire(gate, swap_wire); + builder.connect(swap.target, swap_wire); + + // Route input wires. 
+ let inputs = inputs.as_ref(); + for i in 0..SPONGE_WIDTH { + let in_wire = MonolithGate::::wire_input(i); + let in_wire = Target::wire(gate, in_wire); + builder.connect(inputs[i], in_wire); + } + + // Route lookup wires + for round_ctr in 0..N_ROUNDS { + for i in 0..NUM_BARS { + let target_input: Target = + Target::wire(gate, MonolithGate::::wire_concrete_out(round_ctr, i)); + let target_output = + Target::wire(gate, MonolithGate::::wire_bars_out(round_ctr, i)); + let target_should = SplitAndLookup::::split_le_lookup( + builder, + target_input, + LOOKUP_NUM_LIMBS, + lut_index, + ); // Assumes a single lookup table + builder.connect(target_output, target_should); + } + } + + // Collect output wires. + Self::AlgebraicPermutation::new( + (0..SPONGE_WIDTH).map(|i| Target::wire(gate, MonolithGate::::wire_output(i))), + ) + } +} + +pub(crate) fn add_monolith_lookup_table, const D: usize>( + builder: &mut CircuitBuilder, +) -> usize { + // Add lookup table for Monolith. To ensure that the big lookup-table of Monolith is computed + // and added to the builder only the first time this function is called, we employ a fake small + // lookup-table to the circuit builder: if such a fake table is not available, then we compute + // and add the big Monolith table; otherwise, we skip the computation of the Monolith table and + // we simply return its index + let fake_table: LookupTable = Arc::new(vec![(0u16, 0u16)]); + if let Some(idx) = builder.is_stored(fake_table.clone()) { + idx + 1 + } else { + let fake_idx = builder.add_lookup_table_from_pairs(fake_table); + // use fake lut in order to avoid errors when generating constraints + let zero = builder.zero(); + builder.add_lookup_from_index(zero, fake_idx); + let inp_table: [u16; LOOKUP_SIZE] = core::array::from_fn(|i| i as u16); + let idx = builder.add_lookup_table_from_fn( + |i| { + let limb = i; + match LOOKUP_BITS { + 8 => { + let limbl1 = ((!limb & 0x80) >> 7) | ((!limb & 0x7F) << 1); // Left rotation by 1 + let limbl2 = 
((limb & 0xC0) >> 6) | ((limb & 0x3F) << 2); // Left rotation by 2 + let limbl3 = ((limb & 0xE0) >> 5) | ((limb & 0x1F) << 3); // Left rotation by 3 + + // y_i = x_i + (1 + x_{i+1}) * x_{i+2} * x_{i+3} + let tmp = limb ^ limbl1 & limbl2 & limbl3; + ((tmp & 0x80) >> 7) | ((tmp & 0x7F) << 1) + } + 16 => { + let limbl1 = ((!limb & 0x8000) >> 15) | ((!limb & 0x7FFF) << 1); // Left rotation by 1 + let limbl2 = ((limb & 0xC000) >> 14) | ((limb & 0x3FFF) << 2); // Left rotation by 2 + let limbl3 = ((limb & 0xE000) >> 13) | ((limb & 0x1FFF) << 3); // Left rotation by 3 + + // y_i = x_i + (1 + x_{i+1}) * x_{i+2} * x_{i+3} + let tmp = limb ^ limbl1 & limbl2 & limbl3; + ((tmp & 0x8000) >> 15) | ((tmp & 0x7FFF) << 1) // Final rotation + } + _ => { + panic!("Unsupported lookup size"); + } + } + }, + &inp_table, + ); + assert_eq!(fake_idx + 1, idx); + idx + } +} + +#[cfg(test)] +pub(crate) mod tests { + use crate::monolith_hash::{Monolith, MonolithHash, MonolithPermutation, SPONGE_WIDTH}; + use anyhow::Result; + use log::{info, Level}; + use plonky2::field::extension::Extendable; + use plonky2::hash::hash_types::RichField; + use plonky2::hash::hashing::PlonkyPermutation; + use plonky2::iop::target::Target; + use plonky2::iop::witness::{PartialWitness, WitnessWrite}; + use plonky2::plonk::circuit_builder::CircuitBuilder; + use plonky2::plonk::circuit_data::{CircuitConfig, CircuitData}; + use plonky2::plonk::config::{AlgebraicHasher, GenericConfig, Hasher}; + use plonky2::plonk::proof::ProofWithPublicInputs; + use plonky2::plonk::prover::prove; + use plonky2::util::timing::TimingTree; + + pub(crate) fn test_monolith_hash_circuit< + F: RichField + Extendable + Monolith, + C: GenericConfig, + const D: usize, + >( + config: CircuitConfig, + ) { + let mut builder = CircuitBuilder::new(config); + + let inp_targets_array = builder.add_virtual_target_arr::(); + println!("Num input: {}", inp_targets_array.len()); + let inp_targets = MonolithPermutation::::new(inp_targets_array); + + let 
mut out_targets = + MonolithHash::permute_swapped(inp_targets, builder._false(), &mut builder); + let nr = 1024; + for i in 1..nr { + out_targets = + MonolithHash::permute_swapped(out_targets, builder._false(), &mut builder); + } + builder.register_public_inputs(out_targets.as_ref()); + builder.print_gate_counts(0); + + println!("Num wires: {}", builder.config.num_wires); + println!("Num routed wires: {}", builder.config.num_routed_wires); + + let now = std::time::Instant::now(); + let circuit = builder.build::(); + println!("[Build time] {:?} s", now.elapsed().as_secs()); + println!("Circuit degree bits: {}", circuit.common.degree_bits()); + + let permutation_inputs = (0..SPONGE_WIDTH) + .map(F::from_canonical_usize) + .collect::>(); + + let mut inputs = PartialWitness::new(); + inp_targets + .as_ref() + .iter() + .zip(permutation_inputs.iter()) + .for_each(|(t, val)| inputs.set_target(*t, *val)); + + let now = std::time::Instant::now(); + let proof = circuit.prove(inputs).unwrap(); + println!("[Prove time] {:?} s", now.elapsed().as_secs()); + println!("Proof size (bytes): {}", proof.to_bytes().len()); + + let expected_outputs: [F; SPONGE_WIDTH] = + F::monolith(permutation_inputs.try_into().unwrap()); + + proof + .public_inputs + .iter() + .zip(expected_outputs.iter()) + .for_each(|(v, out)| assert_eq!(*v, *out)); + + let now = std::time::Instant::now(); + circuit.verify(proof).unwrap(); + println!("[Verify time] {:?} ms", now.elapsed()); + } + + pub(crate) fn prove_circuit_with_hash< + F: RichField + Extendable, + C: GenericConfig, + const D: usize, + H: Hasher + AlgebraicHasher, + >( + config: CircuitConfig, + num_ops: usize, + print_timing: bool, + ) -> Result<(CircuitData, ProofWithPublicInputs)> { + let mut builder = CircuitBuilder::::new(config); + let init_t = builder.add_virtual_public_input(); + let mut res_t = builder.add_virtual_target(); + builder.connect(init_t, res_t); + let hash_targets = (0..SPONGE_WIDTH - 1) + .map(|_| 
builder.add_virtual_target()) + .collect::>(); + for _ in 0..num_ops { + res_t = builder.mul(res_t, res_t); + let mut to_be_hashed_elements = vec![res_t]; + to_be_hashed_elements.extend_from_slice(hash_targets.as_slice()); + res_t = builder.hash_or_noop::(to_be_hashed_elements).elements[0] + } + let out_t = builder.add_virtual_public_input(); + let is_eq_t = builder.is_equal(out_t, res_t); + builder.assert_one(is_eq_t.target); + + let data = builder.build::(); + + let mut pw = PartialWitness::::new(); + let input = F::rand(); + pw.set_target(init_t, input); + + let input_hash_elements = hash_targets + .iter() + .map(|&hash_t| { + let elem = F::rand(); + pw.set_target(hash_t, elem); + elem + }) + .collect::>(); + + let mut res = input; + for _ in 0..num_ops { + res = res.mul(res); + let mut to_be_hashed_elements = vec![res]; + to_be_hashed_elements.extend_from_slice(input_hash_elements.as_slice()); + res = H::hash_no_pad(to_be_hashed_elements.as_slice()).elements[0] + } + + pw.set_target(out_t, res); + + let proof = if print_timing { + let mut timing = TimingTree::new("prove", Level::Debug); + let proof = prove(&data.prover_only, &data.common, pw, &mut timing)?; + timing.print(); + info!("proof size: {}", proof.to_bytes().len()); + proof + } else { + data.prove(pw)? 
+ }; + + assert_eq!(proof.public_inputs[0], input); + assert_eq!(proof.public_inputs[1], res); + + Ok((data, proof)) + } + + pub(crate) fn recursive_proof< + F: RichField + Extendable, + C: GenericConfig, + InnerC: GenericConfig, + const D: usize, + >( + inner_proof: ProofWithPublicInputs, + inner_cd: &CircuitData, + config: &CircuitConfig, + ) -> Result<(CircuitData, ProofWithPublicInputs)> + where + C::Hasher: AlgebraicHasher, + InnerC::Hasher: AlgebraicHasher, + { + let mut builder = CircuitBuilder::::new(config.clone()); + let mut pw = PartialWitness::new(); + let pt = builder.add_virtual_proof_with_pis(&inner_cd.common); + pw.set_proof_with_pis_target(&pt, &inner_proof); + + let inner_data = + builder.add_virtual_verifier_data(inner_cd.common.config.fri_config.cap_height); + pw.set_cap_target( + &inner_data.constants_sigmas_cap, + &inner_cd.verifier_only.constants_sigmas_cap, + ); + pw.set_hash_target( + inner_data.circuit_digest, + inner_cd.verifier_only.circuit_digest, + ); + + for &pi_t in pt.public_inputs.iter() { + let t = builder.add_virtual_public_input(); + builder.connect(pi_t, t); + } + builder.verify_proof::(&pt, &inner_data, &inner_cd.common); + let data = builder.build::(); + + let proof = data.prove(pw)?; + + Ok((data, proof)) + } +} diff --git a/plonky2-monolith/src/gates/mod.rs b/plonky2-monolith/src/gates/mod.rs new file mode 100644 index 0000000..0cbed7c --- /dev/null +++ b/plonky2-monolith/src/gates/mod.rs @@ -0,0 +1,37 @@ +/// A gate employed to split a Goldilocks field element in limbs of +/// [`crate::monolith_hash::LOOKUP_BITS`], which is necessary to apply the lookup table encoding the +/// function to be applied in the `Bars` layer; the same gate is also employed to reconstruct a +/// Goldilocks field element from the limbs, after the evaluation of the lookup table to each limb. 
+/// The gate works similarly to the Plonky2 `BaseSum` gate, but it is customized to be employed +/// specifically for the Monolith permutation +pub mod base_sum_custom; +/// This module provides the methods necessary to compute hashes employing Monolith gate in a +/// Plonky2 circuit +pub mod gadget; +/// Monolith gate for Plonky2 circuits +pub mod monolith; + +use crate::{gates::monolith::MonolithGate, monolith_hash::Monolith}; +use plonky2::field::extension::Extendable; +use plonky2::gates::gate::Gate; +use plonky2::hash::hash_types::RichField; +use plonky2::plonk::circuit_data::CircuitConfig; +use std::cmp; + +/// This function provides the recommended circuit configuration to be employed when Monolith +/// permutations are computed inside a circuit with the Monolith gate +pub fn generate_config_for_monolith_gate< + F: RichField + Extendable + Monolith, + const D: usize, +>() -> CircuitConfig { + let needed_wires = cmp::max( + MonolithGate::::new().num_wires(), + CircuitConfig::standard_recursion_config().num_wires, + ); + println!("num of wires = {}", needed_wires); + CircuitConfig { + num_wires: needed_wires, + num_routed_wires: needed_wires, + ..CircuitConfig::standard_recursion_config() + } +} diff --git a/plonky2-monolith/src/gates/monolith.rs b/plonky2-monolith/src/gates/monolith.rs new file mode 100644 index 0000000..ee0d6e4 --- /dev/null +++ b/plonky2-monolith/src/gates/monolith.rs @@ -0,0 +1,488 @@ +use crate::monolith_hash::{Monolith, NUM_BARS, N_ROUNDS, SPONGE_WIDTH}; +use itertools::Itertools; +use plonky2::field::extension::Extendable; +use plonky2::field::types::Field; +use plonky2::gates::gate::Gate; +use plonky2::gates::util::StridedConstraintConsumer; +use plonky2::hash::hash_types::RichField; +use plonky2::iop::ext_target::ExtensionTarget; +use plonky2::iop::generator::{GeneratedValues, WitnessGenerator, WitnessGeneratorRef}; +use plonky2::iop::target::Target; +use plonky2::iop::wire::Wire; +use plonky2::iop::witness::{PartitionWitness, 
Witness, WitnessWrite}; +use plonky2::plonk::circuit_builder::CircuitBuilder; +use plonky2::plonk::circuit_data::CommonCircuitData; +use plonky2::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase}; +use plonky2::util::serialization::{Buffer, IoResult, Read, Write}; +use std::marker::PhantomData; + +/// Evaluates a full Monolith permutation with 12 state elements. +/// +/// This also has some extra features to make it suitable for efficiently verifying Merkle proofs. +/// It has a flag which can be used to swap the first four inputs with the next four, for ordering +/// sibling digests. +#[derive(Debug, Default, Clone, Eq, PartialEq)] +pub struct MonolithGate, const D: usize>(PhantomData); + +impl, const D: usize> MonolithGate { + /// Instantiate a new `MonolithGate` + pub fn new() -> Self { + Self(PhantomData) + } + + /// The wire index for the `i`th input to the permutation. + pub fn wire_input(i: usize) -> usize { + i + } + + /// The wire index for the `i`th output to the permutation. + pub fn wire_output(i: usize) -> usize { + SPONGE_WIDTH + i + } + + /// If this is set to 1, the first four inputs will be swapped with the next four inputs. This + /// is useful for ordering hashes in Merkle proofs. Otherwise, this should be set to 0. + pub const WIRE_SWAP: usize = 2 * SPONGE_WIDTH; + + const START_DELTA: usize = 2 * SPONGE_WIDTH + 1; + + /// A wire which stores `swap * (input[i + 4] - input[i])`; used to compute the swapped inputs. 
+ fn wire_delta(i: usize) -> usize { + assert!(i < 4); + Self::START_DELTA + i + } + + const START_PERM: usize = Self::START_DELTA + 4; + + /// A wire which stores the output of the `i`-th Concrete of the `round`-th round + pub fn wire_concrete_out(round: usize, i: usize) -> usize { + // Configuration: + // 1 Concrete_out for each state element + // 1 Bar_out for each state element which goes through Bars + // = STATE_SIZE + NUM_BARS cells for each round + match round { + 0 => { + debug_assert!(round == 0); + debug_assert!(i < NUM_BARS); + Self::START_PERM + i + } + _ => { + debug_assert!(round > 0); + debug_assert!(i < SPONGE_WIDTH); + Self::START_PERM + (NUM_BARS * 2) + (SPONGE_WIDTH + NUM_BARS) * (round - 1) + i + } + } + } + + /// A wire which stores the output of the `i`-th Bar of the `round`-th round + pub fn wire_bars_out(round: usize, i: usize) -> usize { + debug_assert!(i < NUM_BARS); + Self::START_PERM + NUM_BARS + (SPONGE_WIDTH + NUM_BARS) * round + i + } + + /// End of wire indices, exclusive. + fn end() -> usize { + Self::START_PERM + (NUM_BARS * 2) + (SPONGE_WIDTH + NUM_BARS) * (N_ROUNDS - 1) + } +} + +impl + Monolith, const D: usize> Gate for MonolithGate { + fn id(&self) -> String { + format!("{self:?}") + } + + fn serialize( + &self, + _dst: &mut Vec, + _common_data: &CommonCircuitData, + ) -> IoResult<()> { + Ok(()) + } + + fn deserialize(_src: &mut Buffer, _common_data: &CommonCircuitData) -> IoResult { + Ok(MonolithGate::new()) + } + + fn eval_unfiltered(&self, vars: EvaluationVars) -> Vec { + let mut constraints = Vec::with_capacity(self.num_constraints()); + + // Assert that `swap` is binary. + let swap = vars.local_wires[Self::WIRE_SWAP]; + constraints.push(swap * (swap - F::Extension::ONE)); + + // Assert that each delta wire is set properly: `delta_i = swap * (rhs - lhs)`. 
+ for i in 0..4 { + let input_lhs = vars.local_wires[Self::wire_input(i)]; + let input_rhs = vars.local_wires[Self::wire_input(i + 4)]; + let delta_i = vars.local_wires[Self::wire_delta(i)]; + constraints.push(swap * (input_rhs - input_lhs) - delta_i); + } + + // Compute the possibly-swapped input layer. + let mut state = [F::Extension::ZERO; SPONGE_WIDTH]; + for i in 0..4 { + let delta_i = vars.local_wires[Self::wire_delta(i)]; + let input_lhs = Self::wire_input(i); + let input_rhs = Self::wire_input(i + 4); + state[i] = vars.local_wires[input_lhs] + delta_i; + state[i + 4] = vars.local_wires[input_rhs] - delta_i; + } + for i in 8..SPONGE_WIDTH { + state[i] = vars.local_wires[Self::wire_input(i)]; + } + + // Permutation + ::concrete_field(&mut state, &::ROUND_CONSTANTS[0]); + for (round_ctr, rc) in ::ROUND_CONSTANTS.iter().skip(1).enumerate() { + // Check values after Concrete and set new state after applying bars + let loop_end = match round_ctr { + 0 => NUM_BARS, + _ => SPONGE_WIDTH, + }; + for i in 0..loop_end { + let concrete_out = vars.local_wires[Self::wire_concrete_out(round_ctr, i)]; + constraints.push(state[i] - concrete_out); + // Get values after Bars (this assumes lookups have already been applied, i.e., these are the outputs of Bars) + if i < NUM_BARS { + state[i] = vars.local_wires[Self::wire_bars_out(round_ctr, i)]; + } else { + state[i] = concrete_out; + } + } + + // Bricks + Concrete + ::bricks_field(&mut state); + ::concrete_field(&mut state, rc); + } + + // Final + for i in 0..SPONGE_WIDTH { + constraints.push(state[i] - vars.local_wires[Self::wire_output(i)]); + } + + constraints + } + + fn eval_unfiltered_base_one( + &self, + vars: EvaluationVarsBase, + mut yield_constr: StridedConstraintConsumer, + ) { + // Assert that `swap` is binary. + let swap = vars.local_wires[Self::WIRE_SWAP]; + yield_constr.one(swap * swap.sub_one()); + + // Assert that each delta wire is set properly: `delta_i = swap * (rhs - lhs)`. 
+ for i in 0..4 { + let input_lhs = vars.local_wires[Self::wire_input(i)]; + let input_rhs = vars.local_wires[Self::wire_input(i + 4)]; + let delta_i = vars.local_wires[Self::wire_delta(i)]; + yield_constr.one(swap * (input_rhs - input_lhs) - delta_i); + } + + // Compute the possibly-swapped input layer. + let mut state = [F::ZERO; SPONGE_WIDTH]; + for i in 0..4 { + let delta_i = vars.local_wires[Self::wire_delta(i)]; + let input_lhs = Self::wire_input(i); + let input_rhs = Self::wire_input(i + 4); + state[i] = vars.local_wires[input_lhs] + delta_i; + state[i + 4] = vars.local_wires[input_rhs] - delta_i; + } + for i in 8..SPONGE_WIDTH { + state[i] = vars.local_wires[Self::wire_input(i)]; + } + + // Permutation + ::concrete(&mut state, &::ROUND_CONSTANTS[0]); + for (round_ctr, rc) in ::ROUND_CONSTANTS.iter().skip(1).enumerate() { + // Check values after Concrete and set new state after applying bars + let loop_end = match round_ctr { + 0 => NUM_BARS, + _ => SPONGE_WIDTH, + }; + for i in 0..loop_end { + let concrete_out = vars.local_wires[Self::wire_concrete_out(round_ctr, i)]; + yield_constr.one(state[i] - concrete_out); + // Get values after Bars (this assumes lookups have already been applied, i.e., these are the outputs of Bars) + if i < NUM_BARS { + state[i] = vars.local_wires[Self::wire_bars_out(round_ctr, i)]; + } else { + state[i] = concrete_out; + } + } + + // Bricks + Concrete + ::bricks(&mut state); + ::concrete(&mut state, rc); + } + + // Final + for i in 0..SPONGE_WIDTH { + yield_constr.one(state[i] - vars.local_wires[Self::wire_output(i)]); + } + } + + fn eval_unfiltered_circuit( + &self, + builder: &mut CircuitBuilder, + vars: EvaluationTargets, + ) -> Vec> { + let mut constraints = Vec::with_capacity(self.num_constraints()); + + // Assert that `swap` is binary. 
+ let swap = vars.local_wires[Self::WIRE_SWAP]; + constraints.push(builder.mul_sub_extension(swap, swap, swap)); + + // Assert that each delta wire is set properly: `delta_i = swap * (rhs - lhs)`. + for i in 0..4 { + let input_lhs = vars.local_wires[Self::wire_input(i)]; + let input_rhs = vars.local_wires[Self::wire_input(i + 4)]; + let delta_i = vars.local_wires[Self::wire_delta(i)]; + let diff = builder.sub_extension(input_rhs, input_lhs); + constraints.push(builder.mul_sub_extension(swap, diff, delta_i)); + } + + // Compute the possibly-swapped input layer. + let mut state = [builder.zero_extension(); SPONGE_WIDTH]; + for i in 0..4 { + let delta_i = vars.local_wires[Self::wire_delta(i)]; + let input_lhs = vars.local_wires[Self::wire_input(i)]; + let input_rhs = vars.local_wires[Self::wire_input(i + 4)]; + state[i] = builder.add_extension(input_lhs, delta_i); + state[i + 4] = builder.sub_extension(input_rhs, delta_i); + } + for i in 8..SPONGE_WIDTH { + state[i] = vars.local_wires[Self::wire_input(i)]; + } + + // Permutation + ::concrete_circuit( + builder, + &mut state, + &::ROUND_CONSTANTS[0], + ); + for (round_ctr, rc) in ::ROUND_CONSTANTS.iter().skip(1).enumerate() { + // Check values after Concrete and set new state after applying bars + let loop_end = match round_ctr { + 0 => NUM_BARS, + _ => SPONGE_WIDTH, + }; + for i in 0..loop_end { + let concrete_out = vars.local_wires[Self::wire_concrete_out(round_ctr, i)]; + constraints.push(builder.sub_extension(state[i], concrete_out)); + // Get values after Bars (this assumes lookups have already been applied, i.e., these are the outputs of Bars) + if i < NUM_BARS { + state[i] = vars.local_wires[Self::wire_bars_out(round_ctr, i)]; + } else { + state[i] = concrete_out; + } + } + + // Get values after Bars (this assumes lookups have already been applied, i.e., these are the outputs of Bars) + for i in 0..NUM_BARS { + state[i] = vars.local_wires[Self::wire_bars_out(round_ctr, i)]; + } + + // Bricks + Concrete + 
::bricks_circuit(builder, &mut state); + ::concrete_circuit(builder, &mut state, rc); + } + + // Final + for i in 0..SPONGE_WIDTH { + constraints + .push(builder.sub_extension(state[i], vars.local_wires[Self::wire_output(i)])); + } + + constraints + } + + fn generators(&self, row: usize, _local_constants: &[F]) -> Vec> { + let gen = MonolithGenerator:: { + row, + _phantom: PhantomData, + }; + vec![WitnessGeneratorRef::new(gen)] + } + + fn num_wires(&self) -> usize { + Self::end() + } + + fn num_constants(&self) -> usize { + 0 + } + + fn degree(&self) -> usize { + 2 + } + + fn num_constraints(&self) -> usize { + NUM_BARS + SPONGE_WIDTH * (N_ROUNDS - 1) + SPONGE_WIDTH + 1 + 4 + } +} + +/// Generator for `MonolithGate` wires +#[derive(Debug, Default, Clone, Eq, PartialEq)] +pub struct MonolithGenerator + Monolith, const D: usize> { + row: usize, + _phantom: PhantomData, +} + +impl + Monolith, const D: usize> WitnessGenerator + for MonolithGenerator +{ + fn id(&self) -> String { + "MonolithGenerator".to_string() + } + + fn watch_list(&self) -> Vec { + (0..SPONGE_WIDTH) + .map(|i| MonolithGate::::wire_input(i)) + .chain(Some(MonolithGate::::WIRE_SWAP)) + .chain( + (0..N_ROUNDS) + .cartesian_product(0..NUM_BARS) + .map(|(round, i)| MonolithGate::::wire_bars_out(round, i)), + ) + .map(|column| Target::wire(self.row, column)) + .collect() + } + + fn run(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) -> bool { + let local_wire = |column| Wire { + row: self.row, + column, + }; + + let mut state = (0..SPONGE_WIDTH) + .map_while(|i| witness.try_get_wire(local_wire(MonolithGate::::wire_input(i)))) + .collect::>(); + // exit if some of the input wires have not been already computed + if state.len() < SPONGE_WIDTH { + return false; + } + + let swap_value = + if let Some(wire) = witness.try_get_wire(local_wire(MonolithGate::::WIRE_SWAP)) { + wire + } else { + return false; + }; + debug_assert!(swap_value == F::ZERO || swap_value == F::ONE); + + for i in 0..4 
{ + let delta_i = swap_value * (state[i + 4] - state[i]); + out_buffer.set_wire(local_wire(MonolithGate::::wire_delta(i)), delta_i); + } + + if swap_value == F::ONE { + for i in 0..4 { + state.swap(i, 4 + i); + } + } + + let mut state: [F; SPONGE_WIDTH] = state.try_into().unwrap(); + + // Permutation + ::concrete_field(&mut state, &::ROUND_CONSTANTS[0]); + for (round_ctr, rc) in ::ROUND_CONSTANTS.iter().skip(1).enumerate() { + // Set values after Concrete + let loop_end = match round_ctr { + 0 => NUM_BARS, + _ => SPONGE_WIDTH, + }; + for i in 0..loop_end { + out_buffer.set_wire( + local_wire(MonolithGate::::wire_concrete_out(round_ctr, i)), + state[i], + ); + } + + // Get values after Bars (this assumes lookups have already been applied, i.e., these are the outputs of Bars) + for i in 0..NUM_BARS { + state[i] = match witness.try_get_wire(local_wire( + MonolithGate::::wire_bars_out(round_ctr, i), + )) { + Some(value) => value, + None => return false, + }; + } + + // Bricks + Concrete + ::bricks_field(&mut state); + ::concrete_field(&mut state, rc); + } + + // Final + for i in 0..SPONGE_WIDTH { + out_buffer.set_wire(local_wire(MonolithGate::::wire_output(i)), state[i]); + } + + true + } + + fn serialize(&self, dst: &mut Vec, _common_data: &CommonCircuitData) -> IoResult<()> { + dst.write_usize(self.row) + } + + fn deserialize(src: &mut Buffer, _common_data: &CommonCircuitData) -> IoResult { + let row = src.read_usize()?; + Ok(Self { + row, + _phantom: PhantomData, + }) + } +} + +#[cfg(test)] +mod tests { + use crate::gates::monolith::MonolithGate; + use crate::monolith_hash::monolith_goldilocks::MonolithGoldilocksConfig; + use plonky2::field::goldilocks_field::GoldilocksField; + use plonky2::gates::gate_testing::{test_eval_fns, test_low_degree}; + use plonky2::plonk::config::GenericConfig; + + #[test] + fn wire_indices() { + type F = GoldilocksField; + type Gate = MonolithGate; + + assert_eq!(Gate::wire_input(0), 0); + assert_eq!(Gate::wire_input(11), 11); + 
assert_eq!(Gate::wire_output(0), 12); + assert_eq!(Gate::wire_output(11), 23); + assert_eq!(Gate::WIRE_SWAP, 24); + assert_eq!(Gate::wire_delta(0), 25); + assert_eq!(Gate::wire_delta(3), 28); + assert_eq!(Gate::wire_concrete_out(0, 0), 29); + assert_eq!(Gate::wire_bars_out(0, 0), 33); + assert_eq!(Gate::wire_concrete_out(1, 0), 37); + assert_eq!(Gate::wire_bars_out(1, 0), 49); + assert_eq!(Gate::wire_concrete_out(2, 0), 53); + assert_eq!(Gate::wire_bars_out(2, 0), 65); + assert_eq!(Gate::wire_concrete_out(3, 0), 69); + assert_eq!(Gate::wire_bars_out(3, 0), 81); + assert_eq!(Gate::wire_concrete_out(4, 0), 85); + assert_eq!(Gate::wire_bars_out(4, 0), 97); + assert_eq!(Gate::wire_concrete_out(5, 0), 101); + assert_eq!(Gate::wire_bars_out(5, 0), 113); + } + + #[test] + fn low_degree() { + type F = GoldilocksField; + let gate = MonolithGate::::new(); + test_low_degree(gate) + } + + #[test] + fn eval_fns() { + const D: usize = 2; + type C = MonolithGoldilocksConfig; + type F = >::F; + let gate = MonolithGate::::new(); + test_eval_fns::(gate).unwrap(); + } +} diff --git a/plonky2-monolith/src/lib.rs b/plonky2-monolith/src/lib.rs new file mode 100644 index 0000000..0a08560 --- /dev/null +++ b/plonky2-monolith/src/lib.rs @@ -0,0 +1,9 @@ +#![warn(missing_docs)] +#![allow(clippy::needless_range_loop)] +#![doc = include_str!("../README.md")] + +/// Implementation of Monolith hash function and data structures to employ it in Plonky2 +pub mod monolith_hash; + +/// Implementation of a Plonky2 gate for Monolith and data structures to employ it in Plonky2 circuits +pub mod gates; diff --git a/plonky2-monolith/src/monolith_hash/mod.rs b/plonky2-monolith/src/monolith_hash/mod.rs new file mode 100644 index 0000000..e3ca54f --- /dev/null +++ b/plonky2-monolith/src/monolith_hash/mod.rs @@ -0,0 +1,436 @@ +use plonky2::field::extension::{Extendable, FieldExtension}; +use plonky2::field::types::PrimeField64; +use plonky2::hash::hash_types::{HashOut, RichField}; +use 
plonky2::hash::hashing::{compress, hash_n_to_hash_no_pad, PlonkyPermutation}; +use plonky2::iop::ext_target::ExtensionTarget; +use plonky2::iop::target::Target; +use plonky2::plonk::circuit_builder::CircuitBuilder; +use plonky2::plonk::config::Hasher; +use std::fmt::Debug; + +use unroll::unroll_for_loops; + +/// Monolith implementation for Goldilocks prime field +pub mod monolith_goldilocks; + +// change these values and disable `default-sponge-params` feature if it is needed to change the +// default sponge parameters +#[cfg(not(feature = "default-sponge-params"))] +const CUSTOM_SPONGE_RATE: usize = 8; +#[cfg(not(feature = "default-sponge-params"))] +const CUSTOM_SPONGE_CAPACITY: usize = 4; + +/// This constant describes the number of elements in the outer part of the cryptographic sponge +/// function. +#[cfg(feature = "default-sponge-params")] +pub const SPONGE_RATE: usize = 8; +/// This constant describes the number of elements in the inner part of the cryptographic sponge +/// function. +#[cfg(feature = "default-sponge-params")] +pub const SPONGE_CAPACITY: usize = 4; + +#[cfg(not(feature = "default-sponge-params"))] +pub const SPONGE_RATE: usize = CUSTOM_SPONGE_RATE; +#[cfg(not(feature = "default-sponge-params"))] +pub const SPONGE_CAPACITY: usize = CUSTOM_SPONGE_CAPACITY; + +/// This is the number of elements which constitute the state of the internal permutation and the +/// cryptographic sponge function built from this permutation. +pub const SPONGE_WIDTH: usize = SPONGE_RATE + SPONGE_CAPACITY; +/// Number of state elements involved in the `Bars` layer +pub const NUM_BARS: usize = 4; + +// Monolith applies the same sequence of layers (`Bars`, `Bricks`, `Concrete`) +// in every round, so a single round count is used: there is no split into +// full and partial rounds and no power-map s-box as in Poseidon. +// +// NB: Changing any of these values will require regenerating the precomputed +// round constants and test vectors in `monolith_goldilocks.rs`. 
/// Splits `x` into the pair `(low 64 bits, next 32 bits)`.
///
/// Both components are obtained with truncating casts, so any bits of `x` at position 96 or
/// above are discarded.
#[inline]
pub(crate) fn split(x: u128) -> (u64, u32) {
    let low = x as u64;
    let high = (x >> 64) as u32;
    (low, high)
}
+ const MAT_12: [[u64; SPONGE_WIDTH]; SPONGE_WIDTH]; + + /// Compute the "Bar" component + /// element is split in (16-bit lookups, analogous for 8-bit lookups): + /// [x_3 || x_2 || x_1 || x_0], where x_i is 16 bits large + /// element = 2^48 * x_3 + 2^32 * x_2 + 2^16 * x_1 + x_0 + /// Use lookups on x_3, x_2, x_1, x_0 and obtain y_3, y_2, y_1, y_0 + /// [y_3 || y_2 || y_1 || y_0], where y_i is 16 bits large + /// Output y is set such that y = 2^48 * x_3 + 2^32 * x_2 + 2^16 * x_1 + x_0 + #[inline(always)] + fn bar_64(limb: u64) -> u64 { + match LOOKUP_BITS { + 8 => { + let limbl1 = + ((!limb & 0x8080808080808080) >> 7) | ((!limb & 0x7F7F7F7F7F7F7F7F) << 1); // Left rotation by 1 + let limbl2 = + ((limb & 0xC0C0C0C0C0C0C0C0) >> 6) | ((limb & 0x3F3F3F3F3F3F3F3F) << 2); // Left rotation by 2 + let limbl3 = + ((limb & 0xE0E0E0E0E0E0E0E0) >> 5) | ((limb & 0x1F1F1F1F1F1F1F1F) << 3); // Left rotation by 3 + + // y_i = x_i + (1 + x_{i+1}) * x_{i+2} * x_{i+3} + let tmp = limb ^ limbl1 & limbl2 & limbl3; + ((tmp & 0x8080808080808080) >> 7) | ((tmp & 0x7F7F7F7F7F7F7F7F) << 1) + } + 16 => { + let limbl1 = + ((!limb & 0x8000800080008000) >> 15) | ((!limb & 0x7FFF7FFF7FFF7FFF) << 1); // Left rotation by 1 + let limbl2 = + ((limb & 0xC000C000C000C000) >> 14) | ((limb & 0x3FFF3FFF3FFF3FFF) << 2); // Left rotation by 2 + let limbl3 = + ((limb & 0xE000E000E000E000) >> 13) | ((limb & 0x1FFF1FFF1FFF1FFF) << 3); // Left rotation by 3 + + // y_i = x_i + (1 + x_{i+1}) * x_{i+2} * x_{i+3} + let tmp = limb ^ limbl1 & limbl2 & limbl3; + ((tmp & 0x8000800080008000) >> 15) | ((tmp & 0x7FFF7FFF7FFF7FFF) << 1) + // Final rotation + } + _ => { + panic!("Unsupported lookup size"); + } + } + } + + /// Same as `bar` optimized for u128 + #[inline(always)] + fn bar_u128(el: &mut u128) { + let limb = *el as u64; + *el = match LOOKUP_BITS { + 8 => { + let limbl1 = + ((!limb & 0x8080808080808080) >> 7) | ((!limb & 0x7F7F7F7F7F7F7F7F) << 1); // Left rotation by 1 + let limbl2 = + ((limb & 
0xC0C0C0C0C0C0C0C0) >> 6) | ((limb & 0x3F3F3F3F3F3F3F3F) << 2); // Left rotation by 2 + let limbl3 = + ((limb & 0xE0E0E0E0E0E0E0E0) >> 5) | ((limb & 0x1F1F1F1F1F1F1F1F) << 3); // Left rotation by 3 + + // y_i = x_i + (1 + x_{i+1}) * x_{i+2} * x_{i+3} + let tmp = limb ^ limbl1 & limbl2 & limbl3; + ((tmp & 0x8080808080808080) >> 7) | ((tmp & 0x7F7F7F7F7F7F7F7F) << 1) + } + 16 => { + let limbl1 = + ((!limb & 0x8000800080008000) >> 15) | ((!limb & 0x7FFF7FFF7FFF7FFF) << 1); // Left rotation by 1 + let limbl2 = + ((limb & 0xC000C000C000C000) >> 14) | ((limb & 0x3FFF3FFF3FFF3FFF) << 2); // Left rotation by 2 + let limbl3 = + ((limb & 0xE000E000E000E000) >> 13) | ((limb & 0x1FFF1FFF1FFF1FFF) << 3); // Left rotation by 3 + + // y_i = x_i + (1 + x_{i+1}) * x_{i+2} * x_{i+3} + let tmp = limb ^ limbl1 & limbl2 & limbl3; + ((tmp & 0x8000800080008000) >> 15) | ((tmp & 0x7FFF7FFF7FFF7FFF) << 1) + // Final rotation + } + _ => { + panic!("Unsupported lookup size"); + } + } as u128; + } + + /// Same as `bars` optimized for u128 + fn bars_u128(state_u128: &mut [u128; SPONGE_WIDTH]) { + Self::bar_u128(&mut state_u128[0]); + Self::bar_u128(&mut state_u128[1]); + Self::bar_u128(&mut state_u128[2]); + Self::bar_u128(&mut state_u128[3]); + } + + /// Compute the "Bricks" component + #[inline(always)] + #[unroll_for_loops] + fn bricks(state: &mut [Self; SPONGE_WIDTH]) { + // Feistel Type-3 + for i in (1..SPONGE_WIDTH).rev() { + let prev = state[i - 1]; + let tmp_square = prev * prev; + state[i] += tmp_square; + } + } + + /// Same as `bricks` optimized for u128 + /// Result is not reduced! 
+ #[unroll_for_loops] + fn bricks_u128(state_u128: &mut [u128; SPONGE_WIDTH]) { + // Feistel Type-3 + // Use "& 0xFFFFFFFFFFFFFFFF" to tell the compiler it is dealing with 64-bit values (save + // some instructions for upper half) + for i in (1..SPONGE_WIDTH).rev() { + let prev = state_u128[i - 1]; + let mut tmp_square = + (prev & 0xFFFFFFFFFFFFFFFF_u128) * (prev & 0xFFFFFFFFFFFFFFFF_u128); + tmp_square = Self::from_noncanonical_u128(tmp_square).to_noncanonical_u64() as u128; + state_u128[i] = + (state_u128[i] & 0xFFFFFFFFFFFFFFFF_u128) + (tmp_square & 0xFFFFFFFFFFFFFFFF_u128); + } + } + + /// Same as `bricks` for field extensions of `Self`. + #[inline(always)] + #[unroll_for_loops] + fn bricks_field, const D: usize>( + state: &mut [F; SPONGE_WIDTH], + ) { + // Feistel Type-3 + // Feistel Type-3 + for i in (1..SPONGE_WIDTH).rev() { + let prev = state[i - 1]; + let tmp_square = prev * prev; + state[i] += tmp_square; + } + } + + /// Recursive version of `bricks`. + #[inline(always)] + #[unroll_for_loops] + fn bricks_circuit( + builder: &mut CircuitBuilder, + state: &mut [ExtensionTarget; SPONGE_WIDTH], + ) where + Self: RichField + Extendable, + { + // Feistel Type-3 + for i in (1..SPONGE_WIDTH).rev() { + let prev = state[i - 1]; + state[i] = builder.mul_add_extension(prev, prev, state[i]) + } + } + + /// Compute the "Concrete" component + #[inline(always)] + #[unroll_for_loops] + fn concrete(state: &mut [Self; SPONGE_WIDTH], round_constants: &[u64; SPONGE_WIDTH]) { + let mut state_tmp = [0u128; SPONGE_WIDTH]; + let mut state_u128 = [0u128; SPONGE_WIDTH]; + for (dst, src) in state_u128.iter_mut().zip(state.iter()) { + *dst = src.to_noncanonical_u64() as u128; + } + concrete_u128_with_tmp_buffer::(&state_u128, round_constants, &mut state_tmp); + for (dst, src) in state.iter_mut().zip(state_tmp.iter()) { + *dst = Self::from_noncanonical_u64(*src as u64) + } + } + + /// Same as `concrete` optimized for u128 + fn concrete_u128(state_u128: &mut [u128; SPONGE_WIDTH], 
round_constants: &[u64; SPONGE_WIDTH]) { + let mut state_tmp = [0_u128; SPONGE_WIDTH]; + concrete_u128_with_tmp_buffer::(state_u128, round_constants, &mut state_tmp); + state_u128.copy_from_slice(&state_tmp); + } + + /// Same as `concrete` for field extensions of `Self`. + #[inline(always)] + #[unroll_for_loops] + fn concrete_field, const D: usize>( + state: &mut [F; SPONGE_WIDTH], + round_constants: &[u64; SPONGE_WIDTH], + ) { + let mut state_tmp = vec![F::ZERO; SPONGE_WIDTH]; + for row in 0..SPONGE_WIDTH { + for (column, input) in state.iter().enumerate() { + state_tmp[row] += *input * F::from_canonical_u64(Self::MAT_12[row][column]); + } + state_tmp[row] += F::from_canonical_u64(round_constants[row]); + } + state.copy_from_slice(&state_tmp); + } + + /// Recursive version of `concrete`. + #[inline(always)] + #[unroll_for_loops] + fn concrete_circuit( + builder: &mut CircuitBuilder, + state: &mut [ExtensionTarget; SPONGE_WIDTH], + round_constants: &[u64; SPONGE_WIDTH], + ) where + Self: RichField + Extendable, + { + let mut state_tmp = vec![builder.zero_extension(); SPONGE_WIDTH]; + for row in 0..SPONGE_WIDTH { + for (column, input) in state.iter().enumerate() { + state_tmp[row] = builder.mul_const_add_extension( + Self::from_canonical_u64(Self::MAT_12[row][column]), + *input, + state_tmp[row], + ); + } + state_tmp[row] = builder.add_const_extension( + state_tmp[row], + Self::from_canonical_u64(round_constants[row]), + ); + } + state.copy_from_slice(&state_tmp); + } + + /// Full Monolith permutation + #[inline] + fn monolith(input: [Self; SPONGE_WIDTH]) -> [Self; SPONGE_WIDTH] { + let mut state_u128 = [0; SPONGE_WIDTH]; + for (out, inp) in state_u128.iter_mut().zip(input.iter()) { + *out = inp.to_noncanonical_u64() as u128; + } + + Self::concrete_u128(&mut state_u128, &Self::ROUND_CONSTANTS[0]); + for rc in Self::ROUND_CONSTANTS.iter().skip(1) { + Self::bars_u128(&mut state_u128); + Self::bricks_u128(&mut state_u128); + Self::concrete_u128(&mut state_u128, rc); + 
} + + // Convert back + let mut state_f = [Self::ZERO; SPONGE_WIDTH]; + for (out, inp) in state_f.iter_mut().zip(state_u128.iter()) { + *out = Self::from_canonical_u64(*inp as u64); + } + state_f + } +} + +/// Implementor of Plonky2 `PlonkyPermutation` trait for Monolith +#[derive(Copy, Clone, Default, Debug, PartialEq)] +pub struct MonolithPermutation { + state: [T; SPONGE_WIDTH], +} + +impl Eq for MonolithPermutation {} + +impl AsRef<[T]> for MonolithPermutation { + fn as_ref(&self) -> &[T] { + &self.state + } +} + +trait Permuter: Sized { + fn permute(input: [Self; SPONGE_WIDTH]) -> [Self; SPONGE_WIDTH]; +} + +impl Permuter for F { + fn permute(input: [Self; SPONGE_WIDTH]) -> [Self; SPONGE_WIDTH] { + ::monolith(input) + } +} + +impl Permuter for Target { + fn permute(_input: [Self; SPONGE_WIDTH]) -> [Self; SPONGE_WIDTH] { + panic!("Call `permute_swapped()` instead of `permute()`"); + } +} + +impl PlonkyPermutation + for MonolithPermutation +{ + const RATE: usize = SPONGE_RATE; + const WIDTH: usize = SPONGE_WIDTH; + + fn new>(elts: I) -> Self { + let mut perm = Self { + state: [T::default(); SPONGE_WIDTH], + }; + perm.set_from_iter(elts, 0); + perm + } + + fn set_elt(&mut self, elt: T, idx: usize) { + self.state[idx] = elt; + } + + fn set_from_slice(&mut self, elts: &[T], start_idx: usize) { + let begin = start_idx; + let end = start_idx + elts.len(); + self.state[begin..end].copy_from_slice(elts); + } + + fn set_from_iter>(&mut self, elts: I, start_idx: usize) { + for (s, e) in self.state[start_idx..].iter_mut().zip(elts) { + *s = e; + } + } + + fn permute(&mut self) { + self.state = T::permute(self.state); + } + + fn squeeze(&self) -> &[T] { + &self.state[..Self::RATE] + } +} + +/// Implementor of Plonky2 `Hasher` trait for Monolith +#[derive(Copy, Clone, Debug, Eq, PartialEq)] +pub struct MonolithHash; +impl Hasher for MonolithHash { + const HASH_SIZE: usize = 4 * 8; + type Hash = HashOut; + type Permutation = MonolithPermutation; + + fn hash_no_pad(input: 
&[F]) -> Self::Hash { + hash_n_to_hash_no_pad::(input) + } + + fn two_to_one(left: Self::Hash, right: Self::Hash) -> Self::Hash { + compress::(left, right) + } +} + +#[cfg(test)] +pub(crate) mod test { + use crate::monolith_hash::{Monolith, SPONGE_WIDTH}; + use plonky2::field::types::Field; + + pub(crate) fn check_test_vectors( + test_vectors: Vec<([u64; SPONGE_WIDTH], [u64; SPONGE_WIDTH])>, + ) where + F: Monolith, + { + for (input_, expected_output_) in test_vectors.into_iter() { + let mut input = [F::ZERO; SPONGE_WIDTH]; + for i in 0..SPONGE_WIDTH { + input[i] = F::from_canonical_u64(input_[i]); + } + let output = F::monolith(input); + for i in 0..SPONGE_WIDTH { + let ex_output = F::from_canonical_u64(expected_output_[i]); + assert_eq!(output[i], ex_output); + } + } + } +} diff --git a/plonky2-monolith/src/monolith_hash/monolith_goldilocks.rs b/plonky2-monolith/src/monolith_hash/monolith_goldilocks.rs new file mode 100644 index 0000000..b540584 --- /dev/null +++ b/plonky2-monolith/src/monolith_hash/monolith_goldilocks.rs @@ -0,0 +1,517 @@ +use crate::monolith_hash::monolith_goldilocks::monolith_mds_12::mds_multiply_u128; +use crate::monolith_hash::{Monolith, MonolithHash, LOOKUP_BITS, N_ROUNDS, SPONGE_WIDTH}; +use plonky2::field::extension::quadratic::QuadraticExtension; +use plonky2::field::goldilocks_field::GoldilocksField; +use plonky2::hash::poseidon::PoseidonHash; +use plonky2::plonk::config::GenericConfig; +use serde::Serialize; + +impl Monolith for GoldilocksField { + const ROUND_CONSTANTS: [[u64; SPONGE_WIDTH]; N_ROUNDS + 1] = match LOOKUP_BITS { + 8 => [ + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [ + 13596126580325903823, + 5676126986831820406, + 11349149288412960427, + 3368797843020733411, + 16240671731749717664, + 9273190757374900239, + 14446552112110239438, + 4033077683985131644, + 4291229347329361293, + 13231607645683636062, + 1383651072186713277, + 8898815177417587567, + ], + [ + 2383619671172821638, + 6065528368924797662, + 16737578966352303081, + 
2661700069680749654, + 7414030722730336790, + 18124970299993404776, + 9169923000283400738, + 15832813151034110977, + 16245117847613094506, + 11056181639108379773, + 10546400734398052938, + 8443860941261719174, + ], + [ + 15799082741422909885, + 13421235861052008152, + 15448208253823605561, + 2540286744040770964, + 2895626806801935918, + 8644593510196221619, + 17722491003064835823, + 5166255496419771636, + 1015740739405252346, + 4400043467547597488, + 5176473243271652644, + 4517904634837939508, + ], + [ + 18341030605319882173, + 13366339881666916534, + 6291492342503367536, + 10004214885638819819, + 4748655089269860551, + 1520762444865670308, + 8393589389936386108, + 11025183333304586284, + 5993305003203422738, + 458912836931247573, + 5947003897778655410, + 17184667486285295106, + ], + [ + 15710528677110011358, + 8929476121507374707, + 2351989866172789037, + 11264145846854799752, + 14924075362538455764, + 10107004551857451916, + 18325221206052792232, + 16751515052585522105, + 15305034267720085905, + 15639149412312342017, + 14624541102106656564, + 3542311898554959098, + ], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ], + 16 => [ + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [ + 18336847912085310782, + 16981085523750439062, + 13429031554613510028, + 14626146163475314696, + 17132599202993726423, + 8006190003318006507, + 11343032213505247196, + 14124666955091711556, + 8430380888588022602, + 8028059853581205264, + 10576927460643802925, + 264807431271531499, + ], + [ + 4974395136075591328, + 12767804748363387455, + 4282984340606842818, + 9962032970357721094, + 13290063373589851073, + 682582873026109162, + 1443405731716023143, + 1102365195228642031, + 2045097484032658744, + 4705239685543555952, + 7749631247106030298, + 14498144818552307386, + ], + [ + 2422278540391021322, + 16279967701033470233, + 11928233299971145130, + 289434792182172450, + 9247027096240775287, + 13564504933984041357, + 13716745789926357653, + 17062841883145120930, + 4787227470665224131, + 3941766098336857538, 
    /// Goldilocks-specific `Concrete` layer on the `u128` state.
    ///
    /// Replaces the trait's default implementation by a call to `mds_multiply_u128`, which
    /// also adds `round_constants` and brings every element back to 64 bits.
    /// `mds_multiply_u128` is written for a state of exactly 12 elements, so this override is
    /// only compiled with the `default-sponge-params` feature, where
    /// `SPONGE_WIDTH` = 8 + 4 = 12.
    #[cfg(feature = "default-sponge-params")]
    fn concrete_u128(state_u128: &mut [u128; SPONGE_WIDTH], round_constants: &[u64; SPONGE_WIDTH]) {
        mds_multiply_u128(state_u128, round_constants)
    }
+ /// This follows from the simple fact that every circulant matrix has the columns of the discrete + /// Fourier transform matrix as orthogonal eigenvectors. + /// The implementation also avoids the use of 3-point FFTs, and 3-point iFFTs, and substitutes that + /// with explicit expressions. It also avoids, due to the form of our matrix in the frequency domain, + /// divisions by 2 and repeated modular reductions. This is because of our explicit choice of + /// an MDS matrix that has small powers of 2 entries in frequency domain. + /// The following implementation has benefited greatly from the discussions and insights of + /// Hamish Ivey-Law and Jacqueline Nabaglo of Polygon Zero. + /// The circulant matrix is identified by its first row: [7, 23, 8, 26, 13, 10, 9, 7, 6, 22, 21, 8]. + + // MDS matrix in frequency domain. + // More precisely, this is the output of the three 4-point (real) FFTs of the first column of + // the MDS matrix i.e. just before the multiplication with the appropriate twiddle factors + // and application of the final four 3-point FFT in order to get the full 12-point FFT. + // The entries have been scaled appropriately in order to avoid divisions by 2 in iFFT2 and iFFT4. + // The code to generate the matrix in frequency domain is based on an adaptation of a code, to generate + // MDS matrices efficiently in original domain, that was developed by the Polygon Zero team. + const MDS_FREQ_BLOCK_ONE: [i64; 3] = [16, 8, 16]; + const MDS_FREQ_BLOCK_TWO: [(i64, i64); 3] = [(-1, 2), (-1, 1), (4, 8)]; + const MDS_FREQ_BLOCK_THREE: [i64; 3] = [-8, 1, 1]; + + pub(crate) fn mds_multiply_u128(state: &mut [u128; 12], round_constants: &[u64; 12]) { + // Using the linearity of the operations we can split the state into a low||high decomposition + // and operate on each with no overflow and then combine/reduce the result to a field element. 
+ let mut state_l = [0u64; 12]; + let mut state_h = [0u64; 12]; + + for r in 0..12 { + let s = state[r]; + state_h[r] = (s >> 32) as u64; + state_l[r] = (s as u32) as u64; + } + + let state_h = mds_multiply_freq(state_h); + let state_l = mds_multiply_freq(state_l); + + for r in 0..12 { + // Both have less than 40 bits + state[r] = state_l[r] as u128 + ((state_h[r] as u128) << 32); + state[r] += round_constants[r] as u128; + state[r] = GoldilocksField::from_noncanonical_u96(split(state[r])).0 as u128; + } + } + + // We use split 3 x 4 FFT transform in order to transform our vectors into the frequency domain. + #[inline(always)] + pub(crate) fn mds_multiply_freq(state: [u64; 12]) -> [u64; 12] { + let [s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11] = state; + + let (u0, u1, u2) = fft4_real([s0, s3, s6, s9]); + let (u4, u5, u6) = fft4_real([s1, s4, s7, s10]); + let (u8, u9, u10) = fft4_real([s2, s5, s8, s11]); + + // This where the multiplication in frequency domain is done. More precisely, and with + // the appropriate permuations in between, the sequence of + // 3-point FFTs --> multiplication by twiddle factors --> Hadamard multiplication --> + // 3 point iFFTs --> multiplication by (inverse) twiddle factors + // is "squashed" into one step composed of the functions "block1", "block2" and "block3". + // The expressions in the aforementioned functions are the result of explicit computations + // combined with the Karatsuba trick for the multiplication of Complex numbers. + + let [v0, v4, v8] = block1([u0, u4, u8], MDS_FREQ_BLOCK_ONE); + let [v1, v5, v9] = block2([u1, u5, u9], MDS_FREQ_BLOCK_TWO); + let [v2, v6, v10] = block3([u2, u6, u10], MDS_FREQ_BLOCK_THREE); + // The 4th block is not computed as it is similar to the 2nd one, up to complex conjugation, + // and is, due to the use of the real FFT and iFFT, redundant. 
+ + let [s0, s3, s6, s9] = ifft4_real_unreduced((v0, v1, v2)); + let [s1, s4, s7, s10] = ifft4_real_unreduced((v4, v5, v6)); + let [s2, s5, s8, s11] = ifft4_real_unreduced((v8, v9, v10)); + + [s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11] + } + + #[inline(always)] + fn block1(x: [i64; 3], y: [i64; 3]) -> [i64; 3] { + let [x0, x1, x2] = x; + let [y0, y1, y2] = y; + let z0 = x0 * y0 + x1 * y2 + x2 * y1; + let z1 = x0 * y1 + x1 * y0 + x2 * y2; + let z2 = x0 * y2 + x1 * y1 + x2 * y0; + + [z0, z1, z2] + } + + #[inline(always)] + fn block2(x: [(i64, i64); 3], y: [(i64, i64); 3]) -> [(i64, i64); 3] { + let [(x0r, x0i), (x1r, x1i), (x2r, x2i)] = x; + let [(y0r, y0i), (y1r, y1i), (y2r, y2i)] = y; + let x0s = x0r + x0i; + let x1s = x1r + x1i; + let x2s = x2r + x2i; + let y0s = y0r + y0i; + let y1s = y1r + y1i; + let y2s = y2r + y2i; + + // Compute x0​y0 ​− ix1​y2​ − ix2​y1​ using Karatsuba for complex numbers multiplication + let m0 = (x0r * y0r, x0i * y0i); + let m1 = (x1r * y2r, x1i * y2i); + let m2 = (x2r * y1r, x2i * y1i); + let z0r = (m0.0 - m0.1) + (x1s * y2s - m1.0 - m1.1) + (x2s * y1s - m2.0 - m2.1); + let z0i = (x0s * y0s - m0.0 - m0.1) + (-m1.0 + m1.1) + (-m2.0 + m2.1); + let z0 = (z0r, z0i); + + // Compute x0​y1​ + x1​y0​ − ix2​y2 using Karatsuba for complex numbers multiplication + let m0 = (x0r * y1r, x0i * y1i); + let m1 = (x1r * y0r, x1i * y0i); + let m2 = (x2r * y2r, x2i * y2i); + let z1r = (m0.0 - m0.1) + (m1.0 - m1.1) + (x2s * y2s - m2.0 - m2.1); + let z1i = (x0s * y1s - m0.0 - m0.1) + (x1s * y0s - m1.0 - m1.1) + (-m2.0 + m2.1); + let z1 = (z1r, z1i); + + // Compute x0​y2​ + x1​y1 ​+ x2​y0​ using Karatsuba for complex numbers multiplication + let m0 = (x0r * y2r, x0i * y2i); + let m1 = (x1r * y1r, x1i * y1i); + let m2 = (x2r * y0r, x2i * y0i); + let z2r = (m0.0 - m0.1) + (m1.0 - m1.1) + (m2.0 - m2.1); + let z2i = (x0s * y2s - m0.0 - m0.1) + (x1s * y1s - m1.0 - m1.1) + (x2s * y0s - m2.0 - m2.1); + let z2 = (z2r, z2i); + + [z0, z1, z2] + } + + 
#[inline(always)] + fn block3(x: [i64; 3], y: [i64; 3]) -> [i64; 3] { + let [x0, x1, x2] = x; + let [y0, y1, y2] = y; + let z0 = x0 * y0 - x1 * y2 - x2 * y1; + let z1 = x0 * y1 + x1 * y0 - x2 * y2; + let z2 = x0 * y2 + x1 * y1 + x2 * y0; + + [z0, z1, z2] + } + + /// Real 2-FFT over u64 integers. + #[inline(always)] + fn fft2_real(x: [u64; 2]) -> [i64; 2] { + [(x[0] as i64 + x[1] as i64), (x[0] as i64 - x[1] as i64)] + } + + /// Real 2-iFFT over u64 integers. + /// Division by two to complete the inverse FFT is expected to be performed ***outside*** of this function. + #[inline(always)] + fn ifft2_real_unreduced(y: [i64; 2]) -> [u64; 2] { + [(y[0] + y[1]) as u64, (y[0] - y[1]) as u64] + } + + /// Real 4-FFT over u64 integers. + #[inline(always)] + fn fft4_real(x: [u64; 4]) -> (i64, (i64, i64), i64) { + let [z0, z2] = fft2_real([x[0], x[2]]); + let [z1, z3] = fft2_real([x[1], x[3]]); + let y0 = z0 + z1; + let y1 = (z2, -z3); + let y2 = z0 - z1; + (y0, y1, y2) + } + + /// Real 4-iFFT over u64 integers. + /// Division by four to complete the inverse FFT is expected to be performed ***outside*** of this function. + #[inline(always)] + fn ifft4_real_unreduced(y: (i64, (i64, i64), i64)) -> [u64; 4] { + let z0 = y.0 + y.2; + let z1 = y.0 - y.2; + let z2 = y.1 .0; + let z3 = -y.1 .1; + + let [x0, x2] = ifft2_real_unreduced([z0, z2]); + let [x1, x3] = ifft2_real_unreduced([z1, z3]); + + [x0, x1, x2, x3] + } +} + +/// Configuration using Monolith over the Goldilocks field. 
+#[derive(Debug, Copy, Clone, Eq, PartialEq, Serialize)]
+pub struct MonolithGoldilocksConfig;
+// Plonky2 configuration over Goldilocks with a degree-2 extension: `MonolithHash` is the
+// `Hasher` and `PoseidonHash` the `InnerHasher`.
+impl GenericConfig<2> for MonolithGoldilocksConfig {
+    type F = GoldilocksField;
+    type FE = QuadraticExtension<Self::F>;
+    type Hasher = MonolithHash;
+    type InnerHasher = PoseidonHash;
+}
+
+#[cfg(test)]
+mod tests {
+    // NOTE(review): the generic argument lists in this module were missing from the reviewed
+    // copy of the patch and have been reconstructed. Those fixed by declarations in this file
+    // or by plonky2's trait signatures should be safe; the turbofish on the crate-local helpers
+    // (`check_test_vectors`, `generate_config_for_monolith_gate`, `prove_circuit_with_hash`,
+    // `recursive_proof`, `test_monolith_hash_circuit`), the `F = F` bindings and
+    // `HashConfig<D, InnerC>` are assumptions -- confirm against the helpers' definitions.
+    use crate::gates::gadget::tests::{
+        prove_circuit_with_hash, recursive_proof, test_monolith_hash_circuit,
+    };
+    use crate::gates::generate_config_for_monolith_gate;
+    use crate::monolith_hash::monolith_goldilocks::MonolithGoldilocksConfig;
+    use crate::monolith_hash::test::check_test_vectors;
+    use crate::monolith_hash::{Monolith, MonolithHash, LOOKUP_BITS};
+    use plonky2::field::extension::Extendable;
+    use plonky2::field::goldilocks_field::GoldilocksField;
+    use plonky2::hash::hash_types::RichField;
+    use plonky2::hash::poseidon::PoseidonHash;
+    use plonky2::plonk::circuit_data::CircuitConfig;
+    use plonky2::plonk::config::{
+        AlgebraicHasher, GenericConfig, Hasher, PoseidonGoldilocksConfig,
+    };
+    use rstest::rstest;
+    use serial_test::serial;
+    use std::marker::PhantomData;
+
+    // Hands the (input, expected output) pair matching the configured `LOOKUP_BITS` (8 or 16)
+    // to `check_test_vectors`; any other `LOOKUP_BITS` value panics.
+    #[test]
+    fn test_vectors() {
+        // Test inputs are:
+        // 1. 0..WIDTH-1
+
+        #[rustfmt::skip]
+        let test_vectors12: Vec<([u64; 12], [u64; 12])> = match LOOKUP_BITS {
+            8 => vec![
+                ([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, ],
+                 [5867581605548782913, 588867029099903233, 6043817495575026667, 805786589926590032, 9919982299747097782, 6718641691835914685, 7951881005429661950, 15453177927755089358, 974633365445157727, 9654662171963364206, 6281307445101925412, 13745376999934453119]),
+            ],
+            16 => vec![
+                ([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, ],
+                 [15270549627416999494, 2608801733076195295, 2511564300649802419, 14351608014180687564, 4101801939676807387, 234091379199311770, 3560400203616478913, 17913168886441793528, 7247432905090441163, 667535998170608897, 5848119428178849609, 7505720212650520546]),
+            ],
+            _ => panic!("unsupported lookup size"),
+        };
+
+        check_test_vectors::<GoldilocksField>(test_vectors12);
+    }
+
+    // helper struct employed to bind a Hasher implementation `H` with the circuit configuration to
+    // be employed to build the circuit when such Hasher `H` is employed in the circuit
+    struct HasherConfig<
+        const D: usize,
+        F: RichField + Monolith + Extendable<D>,
+        H: Hasher<F> + AlgebraicHasher<F>,
+    > {
+        field: PhantomData<F>,
+        hasher: PhantomData<H>,
+        circuit_config: CircuitConfig,
+    }
+
+    // Proves and verifies the helper circuit once per combination of the two `#[values]` lists:
+    // the proving configuration `C` and the (hasher, circuit configuration) pair in `config`.
+    #[rstest]
+    #[serial]
+    fn test_circuit_with_hash_functions<
+        F: RichField + Monolith + Extendable<D>,
+        C: GenericConfig<D, F = F>,
+        H: Hasher<F> + AlgebraicHasher<F>,
+        const D: usize,
+    >(
+        #[values(PoseidonGoldilocksConfig, MonolithGoldilocksConfig)] _c: C,
+        #[values(HasherConfig::<2, GoldilocksField, PoseidonHash> {
+            field: PhantomData::default(),
+            hasher: PhantomData::default(),
+            circuit_config: CircuitConfig::standard_recursion_config(),
+        }, HasherConfig::<2, GoldilocksField , MonolithHash> {
+            field: PhantomData::default(),
+            hasher: PhantomData::default(),
+            circuit_config: generate_config_for_monolith_gate::<GoldilocksField, 2>(),
+        })]
+        config: HasherConfig<D, F, H>,
+    ) {
+        // `try_init` fails if a logger is already installed; that is deliberately ignored.
+        let _ = env_logger::builder().is_test(true).try_init();
+
+        let (cd, proof) =
+            prove_circuit_with_hash::<F, C, D, H>(config.circuit_config, 4096, true).unwrap();
+
+        cd.verify(proof).unwrap()
+    }
+    // helper struct employed to bind a GenericConfig `C` with the circuit configuration
+    // to be employed to build the circuit when such `C` is employed in the circuit
+    struct HashConfig<const D: usize, C: GenericConfig<D>> {
+        gen_config: PhantomData<C>,
+        circuit_config: CircuitConfig,
+    }
+
+    // Produces a base proof with the standard recursion configuration, wraps it in a recursive
+    // proof built with `inner_conf.circuit_config`, and verifies the recursive proof.
+    #[rstest]
+    #[serial]
+    fn test_recursive_circuit_with_hash_functions<
+        F: RichField + Monolith + Extendable<D>,
+        C: GenericConfig<D, F = F>,
+        InnerC: GenericConfig<D, F = F>,
+        const D: usize,
+    >(
+        #[values(PoseidonGoldilocksConfig, MonolithGoldilocksConfig)] _c: C,
+        #[values(HashConfig::<2, PoseidonGoldilocksConfig> {
+            gen_config: PhantomData::default(),
+            circuit_config: CircuitConfig::standard_recursion_config(),
+        }, HashConfig::<2, MonolithGoldilocksConfig> {
+            gen_config: PhantomData::default(),
+            circuit_config: generate_config_for_monolith_gate::<GoldilocksField, 2>(),
+        })]
+        inner_conf: HashConfig<D, InnerC>,
+    ) where
+        C::Hasher: AlgebraicHasher<F>,
+        InnerC::Hasher: AlgebraicHasher<F>,
+    {
+        let _ = env_logger::builder().is_test(true).try_init();
+
+        // NOTE(review): `InnerC` and `PoseidonHash` are assumed here (the base circuit is built
+        // with the standard recursion configuration) -- confirm against the original commit.
+        let (cd, proof) = prove_circuit_with_hash::<F, InnerC, D, PoseidonHash>(
+            CircuitConfig::standard_recursion_config(),
+            2048,
+            false,
+        )
+        .unwrap();
+
+        println!("base proof generated");
+
+        println!("base circuit size: {}", cd.common.degree_bits());
+
+        let (rec_cd, rec_proof) =
+            recursive_proof::<F, C, InnerC, D>(proof, &cd, &inner_conf.circuit_config).unwrap();
+
+        println!(
+            "recursive proof generated, recursion circuit size: {}",
+            rec_cd.common.degree_bits()
+        );
+
+        rec_cd.verify(rec_proof).unwrap();
+    }
+
+    // Runs `test_monolith_hash_circuit` with the circuit configuration generated for the
+    // Monolith gate, using `MonolithGoldilocksConfig` and its field.
+    #[test]
+    fn test_monolith_hash() {
+        const D: usize = 2;
+        type C = MonolithGoldilocksConfig;
+        type F = <C as GenericConfig<D>>::F;
+        let config = generate_config_for_monolith_gate::<F, D>();
+        let _ = env_logger::builder().is_test(true).try_init();
+        test_monolith_hash_circuit::<F, C, D>(config)
+    }
+}
diff --git a/plonky2-monolith/tests/integration.rs b/plonky2-monolith/tests/integration.rs
new file mode 100644
index 0000000..8d219ab
--- /dev/null
+++
b/plonky2-monolith/tests/integration.rs
@@ -0,0 +1,83 @@
+use core::ops::Mul;
+use plonky2::field::{goldilocks_field::GoldilocksField, types::Sample};
+use plonky2::hash::hash_types::NUM_HASH_OUT_ELTS;
+use plonky2::iop::witness::{PartialWitness, WitnessWrite};
+use plonky2::plonk::circuit_builder::CircuitBuilder;
+use plonky2::plonk::config::Hasher;
+use plonky2_monolith::monolith_hash::MonolithHash;
+use plonky2_monolith::{
+    gates::generate_config_for_monolith_gate,
+    monolith_hash::monolith_goldilocks::MonolithGoldilocksConfig,
+};
+
+/// Builds a circuit that repeats "square, then `hash_or_noop`" `NUM_OPS` times, replays the
+/// same chain natively to obtain the expected output, then proves and verifies the circuit.
+///
+/// NOTE(review): every hashed vector holds exactly `NUM_HASH_OUT_ELTS` elements (the running
+/// value plus `NUM_HASH_OUT_ELTS - 1` fillers). plonky2's `hash_or_noop` skips hashing for
+/// inputs of at most that length, so this loop may never invoke Monolith -- confirm, and
+/// lengthen the input if the hash is meant to be exercised.
+#[test]
+fn test_circuit_with_monolith() {
+    type F = GoldilocksField;
+    const D: usize = 2;
+    type H = MonolithHash;
+    type C = MonolithGoldilocksConfig;
+    const NUM_OPS: usize = 1024;
+    // NOTE(review): `::<F, D>` is a reconstruction; this helper's signature is not visible
+    // here -- confirm against its definition in `gates/mod.rs`.
+    let config = generate_config_for_monolith_gate::<F, D>();
+    let mut builder = CircuitBuilder::<F, D>::new(config);
+    // Public input 0 seeds the chain; `res_t` tracks the running value in-circuit.
+    let init_t = builder.add_virtual_public_input();
+    let mut res_t = builder.add_virtual_target();
+    builder.connect(init_t, res_t);
+    // Filler targets appended to `res_t` before every `hash_or_noop` call.
+    let hash_targets = (0..NUM_HASH_OUT_ELTS - 1)
+        .map(|_| builder.add_virtual_target())
+        .collect::<Vec<_>>();
+    for _ in 0..NUM_OPS {
+        res_t = builder.mul(res_t, res_t);
+        let mut to_be_hashed_elements = vec![res_t];
+        to_be_hashed_elements.extend_from_slice(hash_targets.as_slice());
+        res_t = builder.hash_or_noop::<H>(to_be_hashed_elements).elements[0]
+    }
+    // Public input 1 is constrained to equal the end of the chain.
+    let out_t = builder.add_virtual_public_input();
+    let is_eq_t = builder.is_equal(out_t, res_t);
+    builder.assert_one(is_eq_t.target);
+
+    let data = builder.build::<C>();
+
+    let mut pw = PartialWitness::<F>::new();
+    let input = F::rand();
+    pw.set_target(init_t, input);
+
+    // Assign a random value to every filler target and keep the values for the native replay.
+    let input_hash_elements = hash_targets
+        .iter()
+        .map(|&hash_t| {
+            let elem = F::rand();
+            pw.set_target(hash_t, elem);
+            elem
+        })
+        .collect::<Vec<_>>();
+
+    // Native replay of the in-circuit chain, giving the value expected on public input 1.
+    let mut res = input;
+    for _ in 0..NUM_OPS {
+        res = res.mul(res);
+        let mut to_be_hashed_elements = vec![res];
+        to_be_hashed_elements.extend_from_slice(input_hash_elements.as_slice());
+        res = H::hash_or_noop(to_be_hashed_elements.as_slice()).elements[0]
+    }
+
+    pw.set_target(out_t, res);
+
+    let proof = data.prove(pw).unwrap();
+
+    assert_eq!(proof.public_inputs[0], input);
+    assert_eq!(proof.public_inputs[1], res);
+
+    data.verify(proof).unwrap();
+}