commit 8c1489b273043d4ae8bb310bf2514900d0f02155 Author: M Alghazwi Date: Mon Oct 7 10:36:11 2024 +0200 init diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..a3547b9 --- /dev/null +++ b/.gitignore @@ -0,0 +1,13 @@ +#IDE Related +.idea + +# Cargo build +/target +Cargo.lock + +# Profile-guided optimization +/tmp +pgo-data.profdata + +# MacOS nuisances +.DS_Store diff --git a/LICENSE.md b/LICENSE.md new file mode 100644 index 0000000..1185acd --- /dev/null +++ b/LICENSE.md @@ -0,0 +1,6 @@ +All crates of this repo are licensed under either of + +* Apache License, Version 2.0, ([LICENSE-APACHE](LICENSE-APACHE) or http://www.apache.org/licenses/LICENSE-2.0) +* MIT license ([LICENSE-MIT](LICENSE-MIT) or http://opensource.org/licenses/MIT) + +at your option. \ No newline at end of file diff --git a/README.md b/README.md new file mode 100644 index 0000000..cec5c1a --- /dev/null +++ b/README.md @@ -0,0 +1,13 @@ +Proof Aggregation +================================ + +This repository contains all work related to proof aggregation (currently only local proof aggregation). + +Repository organization +----------------- + +- [`plonly2_poseidon2`](./plonky2_poseidon2) is the crate for plonky2 which supports the poseidon2 hash function. + +- [`codex-plonky2-circuits`](./codex-plonky2-circuits) contains the codex proof circuits tailored specifically for plonky2. These circuits have the functionality as those in [**here**](https://github.com/codex-storage/codex-storage-proofs-circuits) + +**WARNING**: This repository contains work-in-progress prototypes, and has not received careful code review. It is NOT ready for production use. 
diff --git a/codex-plonky2-circuits/.gitignore b/codex-plonky2-circuits/.gitignore new file mode 100644 index 0000000..a3547b9 --- /dev/null +++ b/codex-plonky2-circuits/.gitignore @@ -0,0 +1,13 @@ +#IDE Related +.idea + +# Cargo build +/target +Cargo.lock + +# Profile-guided optimization +/tmp +pgo-data.profdata + +# MacOS nuisances +.DS_Store diff --git a/codex-plonky2-circuits/Cargo.toml b/codex-plonky2-circuits/Cargo.toml new file mode 100644 index 0000000..d68fee0 --- /dev/null +++ b/codex-plonky2-circuits/Cargo.toml @@ -0,0 +1,23 @@ +[package] +name = "codex-plonky2-circuits" +description = "Codex storage proofs circuits for Plonky2" +authors = ["Mohammed Alghazwi "] +readme = "README.md" +version = "0.1.0" +edition = "2021" + +[dependencies] +anyhow = { version = "1.0.89" } +unroll = { version = "0.1.5", default-features = false } +serde = { version = "1.0.210" , features = ["rc"] } +serde_json = { version = "1.0" } +plonky2 = { version = "0.2.2" } +plonky2_field = { version = "0.2.2", default-features = false } +plonky2_poseidon2 = { path = "../plonky2_poseidon2" } +itertools = { version = "0.12.1", default-features = false } +plonky2_maybe_rayon = { version = "0.2.0", default-features = false } +rand = "0.8.5" + +[dev-dependencies] +criterion = { version = "0.5.1", default-features = false } +tynm = { version = "0.1.6", default-features = false } diff --git a/codex-plonky2-circuits/README.md b/codex-plonky2-circuits/README.md new file mode 100644 index 0000000..1058a84 --- /dev/null +++ b/codex-plonky2-circuits/README.md @@ -0,0 +1,7 @@ +# Codex Plonky2 Circuits +WARNING: This is a work-in-progress prototype, and has not received careful code review. This implementation is NOT ready for production use. + +This crate is an implementation of the [codex storage proofs circuits](https://github.com/codex-storage/codex-storage-proofs-circuits) for the plonky2 proof system. + +## Benchmarks +TODO ... 
\ No newline at end of file diff --git a/codex-plonky2-circuits/src/circuits/merkle_tree_circuit.rs b/codex-plonky2-circuits/src/circuits/merkle_tree_circuit.rs new file mode 100644 index 0000000..f1a0da9 --- /dev/null +++ b/codex-plonky2-circuits/src/circuits/merkle_tree_circuit.rs @@ -0,0 +1,448 @@ +use anyhow::Result; +use plonky2::field::extension::Extendable; +use plonky2::hash::hash_types::RichField; +use plonky2::hash::hashing::hash_n_to_m_no_pad; +use plonky2::iop::target::{BoolTarget, Target}; +use plonky2::iop::witness::{PartialWitness, WitnessWrite, Witness}; +use plonky2::plonk::circuit_builder::CircuitBuilder; +use plonky2::plonk::circuit_data::{CircuitConfig, CircuitData, VerifierCircuitData}; +use plonky2::plonk::config::{AlgebraicHasher, GenericConfig, Hasher, GenericHashOut}; +use plonky2::plonk::proof::ProofWithPublicInputs; +use std::marker::PhantomData; +use itertools::Itertools; + +use crate::merkle_tree::MerkleTree; +use plonky2::hash::poseidon::PoseidonHash; + +use plonky2::hash::hash_types::{HashOutTarget, MerkleCapTarget, NUM_HASH_OUT_ELTS}; +use crate::merkle_tree::{MerkleProof, MerkleProofTarget}; +use plonky2_poseidon2::poseidon2_hash::poseidon2::{Poseidon2, Poseidon2Hash}; + +use plonky2::field::goldilocks_field::GoldilocksField; +use plonky2::plonk::config::PoseidonGoldilocksConfig; +use plonky2::plonk::proof::Proof; + +use plonky2::hash::hashing::PlonkyPermutation; +use plonky2::plonk::circuit_data::VerifierCircuitTarget; +use crate::merkle_tree::MerkleCap; + +// size of leaf data (in number of field elements) +pub const LEAF_LEN: usize = 4; + +#[derive(Clone, Debug, Eq, PartialEq)] +pub struct MerkleTreeTargets< + F: RichField + Extendable + Poseidon2, + C: GenericConfig, + const D: usize, + H: Hasher + AlgebraicHasher, +> { + pub proof_target: MerkleProofTarget, + pub cap_target: MerkleCapTarget, + pub leaf: Vec, + pub leaf_index_target: Target, + _phantom: PhantomData<(C,H)>, +} + +#[derive(Clone, Debug, Eq, PartialEq)] +pub struct 
MerkleTreeCircuit< + F: RichField + Extendable + Poseidon2, + C: GenericConfig, + const D: usize, + H: Hasher + AlgebraicHasher, +> { + pub tree: MerkleTree, + pub _phantom: PhantomData, +} + +impl< + F: RichField + Extendable + Poseidon2, + C: GenericConfig, + const D: usize, + H: Hasher + AlgebraicHasher, +> MerkleTreeCircuit{ + + pub fn tree_height(&self) -> usize { + self.tree.leaves.len().trailing_zeros() as usize + } + + // build the circuit and returns the circuit data + // note, this fn generate circuit data with + pub fn build_circuit(&mut self, builder: &mut CircuitBuilder::) -> MerkleTreeTargets{ + + let proof_t = MerkleProofTarget { + siblings: builder.add_virtual_hashes(self.tree_height()-self.tree.cap.height()), + }; + + let cap_t = builder.add_virtual_cap(self.tree.cap.height()); + + let leaf_index_t = builder.add_virtual_target(); + + let leaf_index_bits = builder.split_le(leaf_index_t, self.tree_height()); + + // NOTE: takes the length from const LEAF_LEN and assume all lengths are the same + let leaf_t: [Target; LEAF_LEN] = builder.add_virtual_targets(LEAF_LEN).try_into().unwrap(); + + let zero = builder.zero(); + // let mut mt = MT(self.tree.clone()); + self.verify_merkle_proof_to_cap_circuit( + builder, leaf_t.to_vec(), &leaf_index_bits, &cap_t, &proof_t, + ); + + MerkleTreeTargets{ + // depth: 0, + // cap_height: 0, + proof_target: proof_t, + cap_target: cap_t, + leaf: leaf_t.to_vec(), + leaf_index_target: leaf_index_t, + _phantom: Default::default(), + } + } + + pub fn fill_targets( + &self, + pw: &mut PartialWitness, + // leaf_data: Vec, + leaf_index: usize, + targets: MerkleTreeTargets, + ) { + let proof = self.tree.prove(leaf_index); + + for i in 0..proof.siblings.len() { + pw.set_hash_target(targets.proof_target.siblings[i], proof.siblings[i]); + } + + // set cap target manually + // pw.set_cap_target(&cap_t, &tree.cap); + for (ht, h) in targets.cap_target.0.iter().zip(&self.tree.cap.0) { + pw.set_hash_target(*ht, *h); + } + + 
pw.set_target( + targets.leaf_index_target, + F::from_canonical_usize(leaf_index), + ); + + for j in 0..targets.leaf.len() { + pw.set_target(targets.leaf[j], self.tree.leaves[leaf_index][j]); + } + + } + + pub fn prove( + &self, + data: CircuitData, + pw: PartialWitness + ) -> Result> { + let proof = data.prove(pw); + return proof + } + + // function to automate build and prove, useful for quick testing + pub fn build_and_prove( + &mut self, + // builder: &mut CircuitBuilder::, + config: CircuitConfig, + // pw: &mut PartialWitness, + leaf_index: usize, + // data: CircuitData, + ) -> Result<(CircuitData,ProofWithPublicInputs)> { + let mut builder = CircuitBuilder::::new(config); + let mut pw = PartialWitness::new(); + // merkle proof + let merkle_proof = self.tree.prove(leaf_index); + let proof_t = MerkleProofTarget { + siblings: builder.add_virtual_hashes(merkle_proof.siblings.len()), + }; + + for i in 0..merkle_proof.siblings.len() { + pw.set_hash_target(proof_t.siblings[i], merkle_proof.siblings[i]); + } + + // merkle cap target + let cap_t = builder.add_virtual_cap(self.tree.cap.height()); + // set cap target manually + // pw.set_cap_target(&cap_t, &tree.cap); + for (ht, h) in cap_t.0.iter().zip(&self.tree.cap.0) { + pw.set_hash_target(*ht, *h); + } + + // leaf index target + let leaf_index_t = builder.constant(F::from_canonical_usize(leaf_index)); + let leaf_index_bits = builder.split_le(leaf_index_t, self.tree_height()); + + // leaf targets + // NOTE: takes the length from const LEAF_LEN and assume all lengths are the same + // let leaf_t = builder.add_virtual_targets(LEAF_LEN); + let leaf_t = builder.add_virtual_targets(self.tree.leaves[leaf_index].len()); + for j in 0..leaf_t.len() { + pw.set_target(leaf_t[j], self.tree.leaves[leaf_index][j]); + } + + // let mut mt = MT(self.tree.clone()); + self.verify_merkle_proof_to_cap_circuit( + &mut builder, leaf_t.to_vec(), &leaf_index_bits, &cap_t, &proof_t, + ); + let data = builder.build::(); + let proof = 
data.prove(pw).unwrap(); + + Ok((data, proof)) + } + + pub fn verify( + &self, + verifier_data: &VerifierCircuitData, + public_inputs: Vec, + proof: Proof + ) -> Result<()> { + verifier_data.verify(ProofWithPublicInputs { + proof, + public_inputs, + }) + } +} + +impl + Poseidon2, const D: usize, C: GenericConfig, H: Hasher + AlgebraicHasher,> MerkleTreeCircuit { + + pub fn verify_merkle_proof_circuit( + &mut self, + builder: &mut CircuitBuilder, + leaf_data: Vec, + leaf_index_bits: &[BoolTarget], + merkle_root: HashOutTarget, + proof: &MerkleProofTarget, + ) { + let merkle_cap = MerkleCapTarget(vec![merkle_root]); + self.verify_merkle_proof_to_cap_circuit(builder, leaf_data, leaf_index_bits, &merkle_cap, proof); + } + + pub fn verify_merkle_proof_to_cap_circuit( + &mut self, + builder: &mut CircuitBuilder, + leaf_data: Vec, + leaf_index_bits: &[BoolTarget], + merkle_cap: &MerkleCapTarget, + proof: &MerkleProofTarget, + ) { + let cap_index = builder.le_sum(leaf_index_bits[proof.siblings.len()..].iter().copied()); + self.verify_merkle_proof_to_cap_with_cap_index_circuit( + builder, + leaf_data, + leaf_index_bits, + cap_index, + merkle_cap, + proof, + ); + } + + pub fn verify_merkle_proof_to_cap_with_cap_index_circuit( + &mut self, + builder: &mut CircuitBuilder, + leaf_data: Vec, + leaf_index_bits: &[BoolTarget], + cap_index: Target, + merkle_cap: &MerkleCapTarget, + proof: &MerkleProofTarget, + ) { + debug_assert!(H::AlgebraicPermutation::RATE >= NUM_HASH_OUT_ELTS); + + let zero = builder.zero(); + let mut state: HashOutTarget = builder.hash_or_noop::(leaf_data); + debug_assert_eq!(state.elements.len(), NUM_HASH_OUT_ELTS); + + for (&bit, &sibling) in leaf_index_bits.iter().zip(&proof.siblings) { + debug_assert_eq!(sibling.elements.len(), NUM_HASH_OUT_ELTS); + + let mut perm_inputs = H::AlgebraicPermutation::default(); + perm_inputs.set_from_slice(&state.elements, 0); + perm_inputs.set_from_slice(&sibling.elements, NUM_HASH_OUT_ELTS); + // Ensure the rest of the 
state, if any, is zero: + perm_inputs.set_from_iter(core::iter::repeat(zero), 2 * NUM_HASH_OUT_ELTS); + // let perm_outs = builder.permute_swapped::(perm_inputs, bit); + let perm_outs = H::permute_swapped(perm_inputs, bit, builder); + let hash_outs = perm_outs.squeeze()[0..NUM_HASH_OUT_ELTS] + .try_into() + .unwrap(); + state = HashOutTarget { + elements: hash_outs, + }; + } + + for i in 0..NUM_HASH_OUT_ELTS { + let result = builder.random_access( + cap_index, + merkle_cap.0.iter().map(|h| h.elements[i]).collect(), + ); + builder.connect(result, state.elements[i]); + } + } + + pub fn verify_batch_merkle_proof_to_cap_with_cap_index_circuit( + &mut self, + builder: &mut CircuitBuilder, + leaf_data: &[Vec], + leaf_heights: &[usize], + leaf_index_bits: &[BoolTarget], + cap_index: Target, + merkle_cap: &MerkleCapTarget, + proof: &MerkleProofTarget, + ) { + debug_assert!(H::AlgebraicPermutation::RATE >= NUM_HASH_OUT_ELTS); + + let zero = builder.zero(); + let mut state: HashOutTarget = builder.hash_or_noop::(leaf_data[0].clone()); + debug_assert_eq!(state.elements.len(), NUM_HASH_OUT_ELTS); + + let mut current_height = leaf_heights[0]; + let mut leaf_data_index = 1; + for (&bit, &sibling) in leaf_index_bits.iter().zip(&proof.siblings) { + debug_assert_eq!(sibling.elements.len(), NUM_HASH_OUT_ELTS); + + let mut perm_inputs = H::AlgebraicPermutation::default(); + perm_inputs.set_from_slice(&state.elements, 0); + perm_inputs.set_from_slice(&sibling.elements, NUM_HASH_OUT_ELTS); + // Ensure the rest of the state, if any, is zero: + perm_inputs.set_from_iter(core::iter::repeat(zero), 2 * NUM_HASH_OUT_ELTS); + // let perm_outs = builder.permute_swapped::(perm_inputs, bit); + let perm_outs = H::permute_swapped(perm_inputs, bit, builder); + let hash_outs = perm_outs.squeeze()[0..NUM_HASH_OUT_ELTS] + .try_into() + .unwrap(); + state = HashOutTarget { + elements: hash_outs, + }; + current_height -= 1; + + if leaf_data_index < leaf_heights.len() + && current_height == 
leaf_heights[leaf_data_index] + { + let mut new_leaves = state.elements.to_vec(); + new_leaves.extend_from_slice(&leaf_data[leaf_data_index]); + state = builder.hash_or_noop::(new_leaves); + + leaf_data_index += 1; + } + } + + for i in 0..NUM_HASH_OUT_ELTS { + let result = builder.random_access( + cap_index, + merkle_cap.0.iter().map(|h| h.elements[i]).collect(), + ); + builder.connect(result, state.elements[i]); + } + } + + pub fn connect_hashes(&mut self, builder: &mut CircuitBuilder, x: HashOutTarget, y: HashOutTarget) { + for i in 0..NUM_HASH_OUT_ELTS { + builder.connect(x.elements[i], y.elements[i]); + } + } + + pub fn connect_merkle_caps(&mut self, builder: &mut CircuitBuilder, x: &MerkleCapTarget, y: &MerkleCapTarget) { + for (h0, h1) in x.0.iter().zip_eq(&y.0) { + self.connect_hashes(builder, *h0, *h1); + } + } + + pub fn connect_verifier_data(&mut self, builder: &mut CircuitBuilder, x: &VerifierCircuitTarget, y: &VerifierCircuitTarget) { + self.connect_merkle_caps(builder, &x.constants_sigmas_cap, &y.constants_sigmas_cap); + self.connect_hashes(builder, x.circuit_digest, y.circuit_digest); + } +} + +#[cfg(test)] +pub mod tests { + use std::time::Instant; + use rand::rngs::OsRng; + use rand::Rng; + + use super::*; + use plonky2::field::types::Field; + use crate::merkle_tree::MerkleTree; + use plonky2::iop::witness::{PartialWitness, WitnessWrite}; + use plonky2::plonk::circuit_data::CircuitConfig; + use plonky2::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; + + pub fn random_data(n: usize, k: usize) -> Vec> { + (0..n).map(|_| F::rand_vec(k)).collect() + } + + #[test] + fn test_merkle_circuit() -> Result<()> { + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + + type H = PoseidonHash; + + // create Merkle tree + let log_n = 8; + let n = 1 << log_n; + let cap_height = 1; + let leaves = random_data::(n, LEAF_LEN); + let tree = MerkleTree::>::Hasher>::new(leaves, cap_height); + + // ---- prover zone ---- + // Build and 
prove + let start_build = Instant::now(); + let mut mt_circuit = MerkleTreeCircuit::{ tree: tree.clone(), _phantom: Default::default() }; + let leaf_index: usize = OsRng.gen_range(0..n); + let config = CircuitConfig::standard_recursion_config(); + let (data, proof_with_pub_input) = mt_circuit.build_and_prove(config,leaf_index).unwrap(); + println!("build and prove time is: {:?}", start_build.elapsed()); + + let vd = data.verifier_data(); + let pub_input = proof_with_pub_input.public_inputs; + let proof = proof_with_pub_input.proof; + + // ---- verifier zone ---- + let start_verifier = Instant::now(); + assert!(mt_circuit.verify(&vd,pub_input,proof).is_ok()); + println!("verify time is: {:?}", start_verifier.elapsed()); + + Ok(()) + } + + #[test] + fn mod_test_merkle_circuit() -> Result<()> { + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + + // create Merkle tree + let log_n = 8; + let n = 1 << log_n; + let cap_height = 0; + let leaves = random_data::(n, LEAF_LEN); + let tree = MerkleTree::>::Hasher>::new(leaves, cap_height); + + // Build circuit + let start_build = Instant::now(); + let mut mt_circuit = MerkleTreeCircuit{ tree: tree.clone(), _phantom: Default::default() }; + let config = CircuitConfig::standard_recursion_config(); + let mut builder = CircuitBuilder::::new(config); + let targets = mt_circuit.build_circuit(&mut builder); + let data = builder.build::(); + let vd = data.verifier_data(); + println!("build time is: {:?}", start_build.elapsed()); + + // Prover Zone + let start_prover = Instant::now(); + let mut pw = PartialWitness::new(); + let leaf_index: usize = OsRng.gen_range(0..n); + let proof = tree.prove(leaf_index); + mt_circuit.fill_targets(&mut pw, leaf_index, targets); + let proof_with_pub_input = mt_circuit.prove(data,pw).unwrap(); + let pub_input = proof_with_pub_input.public_inputs; + let proof = proof_with_pub_input.proof; + println!("prove time is: {:?}", start_prover.elapsed()); + + // Verifier zone + let 
start_verifier = Instant::now(); + assert!(mt_circuit.verify(&vd,pub_input,proof).is_ok()); + println!("verify time is: {:?}", start_verifier.elapsed()); + + Ok(()) + } +} \ No newline at end of file diff --git a/codex-plonky2-circuits/src/circuits/mod.rs b/codex-plonky2-circuits/src/circuits/mod.rs new file mode 100644 index 0000000..2fc0c5c --- /dev/null +++ b/codex-plonky2-circuits/src/circuits/mod.rs @@ -0,0 +1 @@ +pub mod merkle_tree_circuit; \ No newline at end of file diff --git a/codex-plonky2-circuits/src/lib.rs b/codex-plonky2-circuits/src/lib.rs new file mode 100644 index 0000000..1034758 --- /dev/null +++ b/codex-plonky2-circuits/src/lib.rs @@ -0,0 +1,2 @@ +pub mod circuits; +pub mod merkle_tree; \ No newline at end of file diff --git a/codex-plonky2-circuits/src/merkle_tree/mod.rs b/codex-plonky2-circuits/src/merkle_tree/mod.rs new file mode 100644 index 0000000..2c8be2b --- /dev/null +++ b/codex-plonky2-circuits/src/merkle_tree/mod.rs @@ -0,0 +1,378 @@ +// An adapted implementation of Merkle tree +// based on the original plonky2 merkle tree implementation + +use core::mem::MaybeUninit; +use core::slice; +use anyhow::{ensure, Result}; +use plonky2_maybe_rayon::*; +use serde::{Deserialize, Serialize}; + +use plonky2::hash::hash_types::{HashOutTarget, RichField}; +use plonky2::plonk::config::{GenericHashOut, Hasher}; +use plonky2::util::log2_strict; + +#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)] +#[serde(bound = "")] +pub struct MerkleCap>(pub Vec); + +impl> Default for MerkleCap { + fn default() -> Self { + Self(Vec::new()) + } +} + +impl> MerkleCap { + pub fn len(&self) -> usize { + self.0.len() + } + + pub fn is_empty(&self) -> bool { + self.len() == 0 + } + + pub fn height(&self) -> usize { + log2_strict(self.len()) + } + + pub fn flatten(&self) -> Vec { + self.0.iter().flat_map(|&h| h.to_vec()).collect() + } +} + +#[derive(Clone, Debug, Eq, PartialEq)] +pub struct MerkleTree> { + pub leaves: Vec>, + + pub digests: Vec, + + pub 
cap: MerkleCap, +} + +#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)] +#[serde(bound = "")] +pub struct MerkleProof> { + /// The Merkle digest of each sibling subtree, staying from the bottommost layer. + pub siblings: Vec, +} + +impl> MerkleProof { + pub fn len(&self) -> usize { + self.siblings.len() + } + + pub fn is_empty(&self) -> bool { + self.len() == 0 + } +} + +#[derive(Clone, Debug, Eq, PartialEq)] +pub struct MerkleProofTarget { + /// The Merkle digest of each sibling subtree, staying from the bottommost layer. + pub siblings: Vec, +} + +impl> Default for MerkleTree { + fn default() -> Self { + Self { + leaves: Vec::new(), + digests: Vec::new(), + cap: MerkleCap::default(), + } + } +} + +pub(crate) fn capacity_up_to_mut(v: &mut Vec, len: usize) -> &mut [MaybeUninit] { + assert!(v.capacity() >= len); + let v_ptr = v.as_mut_ptr().cast::>(); + unsafe { + slice::from_raw_parts_mut(v_ptr, len) + } +} + +pub(crate) fn fill_subtree>( + digests_buf: &mut [MaybeUninit], + leaves: &[Vec], +) -> H::Hash { + assert_eq!(leaves.len(), digests_buf.len() / 2 + 1); + if digests_buf.is_empty() { + H::hash_or_noop(&leaves[0]) + } else { + let (left_digests_buf, right_digests_buf) = digests_buf.split_at_mut(digests_buf.len() / 2); + let (left_digest_mem, left_digests_buf) = left_digests_buf.split_last_mut().unwrap(); + let (right_digest_mem, right_digests_buf) = right_digests_buf.split_first_mut().unwrap(); + + let (left_leaves, right_leaves) = leaves.split_at(leaves.len() / 2); + + let (left_digest, right_digest) = plonky2_maybe_rayon::join( + || fill_subtree::(left_digests_buf, left_leaves), + || fill_subtree::(right_digests_buf, right_leaves), + ); + + left_digest_mem.write(left_digest); + right_digest_mem.write(right_digest); + H::two_to_one(left_digest, right_digest) + } +} + +pub(crate) fn fill_digests_buf>( + digests_buf: &mut [MaybeUninit], + cap_buf: &mut [MaybeUninit], + leaves: &[Vec], + cap_height: usize, +) { + + if digests_buf.is_empty() { + 
debug_assert_eq!(cap_buf.len(), leaves.len()); + cap_buf + .par_iter_mut() + .zip(leaves) + .for_each(|(cap_buf, leaf)| { + cap_buf.write(H::hash_or_noop(leaf)); + }); + return; + } + + let subtree_digests_len = digests_buf.len() >> cap_height; + let subtree_leaves_len = leaves.len() >> cap_height; + let digests_chunks = digests_buf.par_chunks_exact_mut(subtree_digests_len); + let leaves_chunks = leaves.par_chunks_exact(subtree_leaves_len); + assert_eq!(digests_chunks.len(), cap_buf.len()); + assert_eq!(digests_chunks.len(), leaves_chunks.len()); + digests_chunks.zip(cap_buf).zip(leaves_chunks).for_each( + |((subtree_digests, subtree_cap), subtree_leaves)| { + + subtree_cap.write(fill_subtree::(subtree_digests, subtree_leaves)); + }, + ); +} + +pub(crate) fn merkle_tree_prove>( + leaf_index: usize, + leaves_len: usize, + cap_height: usize, + digests: &[H::Hash], +) -> Vec { + let num_layers = log2_strict(leaves_len) - cap_height; + debug_assert_eq!(leaf_index >> (cap_height + num_layers), 0); + + let digest_len = 2 * (leaves_len - (1 << cap_height)); + assert_eq!(digest_len, digests.len()); + + let digest_tree: &[H::Hash] = { + let tree_index = leaf_index >> num_layers; + let tree_len = digest_len >> cap_height; + &digests[tree_len * tree_index..tree_len * (tree_index + 1)] + }; + + // Mask out high bits to get the index within the sub-tree. + let mut pair_index = leaf_index & ((1 << num_layers) - 1); + (0..num_layers) + .map(|i| { + let parity = pair_index & 1; + pair_index >>= 1; + + // The layers' data is interleaved as follows: + // [layer 0, layer 1, layer 0, layer 2, layer 0, layer 1, layer 0, layer 3, ...]. + // Each of the above is a pair of siblings. + // `pair_index` is the index of the pair within layer `i`. + // The index of that the pair within `digests` is + // `pair_index * 2 ** (i + 1) + (2 ** i - 1)`. + let siblings_index = (pair_index << (i + 1)) + (1 << i) - 1; + // We have an index for the _pair_, but we want the index of the _sibling_. 
+ // Double the pair index to get the index of the left sibling. Conditionally add `1` + // if we are to retrieve the right sibling. + let sibling_index = 2 * siblings_index + (1 - parity); + digest_tree[sibling_index] + }) + .collect() +} + +impl> MerkleTree { + pub fn new(leaves: Vec>, cap_height: usize) -> Self { + let log2_leaves_len = log2_strict(leaves.len()); + assert!( + cap_height <= log2_leaves_len, + "cap_height={} should be at most log2(leaves.len())={}", + cap_height, + log2_leaves_len + ); + + let num_digests = 2 * (leaves.len() - (1 << cap_height)); + let mut digests = Vec::with_capacity(num_digests); + + let len_cap = 1 << cap_height; + let mut cap = Vec::with_capacity(len_cap); + + let digests_buf = capacity_up_to_mut(&mut digests, num_digests); + let cap_buf = capacity_up_to_mut(&mut cap, len_cap); + fill_digests_buf::(digests_buf, cap_buf, &leaves[..], cap_height); + + unsafe { + // SAFETY: `fill_digests_buf` and `cap` initialized the spare capacity up to + // `num_digests` and `len_cap`, resp. + digests.set_len(num_digests); + cap.set_len(len_cap); + } + + Self { + leaves, + digests, + cap: MerkleCap(cap), + } + } + + pub fn get(&self, i: usize) -> &[F] { + &self.leaves[i] + } + + // Create a Merkle proof from a leaf index. + pub fn prove(&self, leaf_index: usize) -> MerkleProof { + let cap_height = log2_strict(self.cap.len()); + let siblings = + merkle_tree_prove::(leaf_index, self.leaves.len(), cap_height, &self.digests); + + MerkleProof { siblings } + } +} + +/// Verifies that the given leaf data is present at the given index in the Merkle tree with the +/// given root. +pub fn verify_merkle_proof>( + leaf_data: Vec, + leaf_index: usize, + merkle_root: H::Hash, + proof: &MerkleProof, +) -> Result<()> { + let merkle_cap = MerkleCap(vec![merkle_root]); + verify_merkle_proof_to_cap(leaf_data, leaf_index, &merkle_cap, proof) +} + +/// Verifies that the given leaf data is present at the given index in the Merkle tree with the +/// given cap. 
+pub fn verify_merkle_proof_to_cap>( + leaf_data: Vec, + leaf_index: usize, + merkle_cap: &MerkleCap, + proof: &MerkleProof, +) -> Result<()> { + verify_batch_merkle_proof_to_cap( + &[leaf_data.clone()], + &[proof.siblings.len()], + leaf_index, + merkle_cap, + proof, + ) +} + +/// Verifies that the given leaf data is present at the given index in the Field Merkle tree with the +/// given cap. +pub fn verify_batch_merkle_proof_to_cap>( + leaf_data: &[Vec], + leaf_heights: &[usize], + mut leaf_index: usize, + merkle_cap: &MerkleCap, + proof: &MerkleProof, +) -> Result<()> { + assert_eq!(leaf_data.len(), leaf_heights.len()); + let mut current_digest = H::hash_or_noop(&leaf_data[0]); + let mut current_height = leaf_heights[0]; + let mut leaf_data_index = 1; + for &sibling_digest in &proof.siblings { + let bit = leaf_index & 1; + leaf_index >>= 1; + current_digest = if bit == 1 { + H::two_to_one(sibling_digest, current_digest) + } else { + H::two_to_one(current_digest, sibling_digest) + }; + current_height -= 1; + + if leaf_data_index < leaf_heights.len() && current_height == leaf_heights[leaf_data_index] { + let mut new_leaves = current_digest.to_vec(); + new_leaves.extend_from_slice(&leaf_data[leaf_data_index]); + current_digest = H::hash_or_noop(&new_leaves); + leaf_data_index += 1; + } + } + assert_eq!(leaf_data_index, leaf_data.len()); + ensure!( + current_digest == merkle_cap.0[leaf_index], + "Invalid Merkle proof." 
+ ); + + Ok(()) +} + +#[cfg(test)] +pub(crate) mod tests { + use anyhow::Result; + + use super::*; + use plonky2::field::extension::Extendable; + use crate::merkle_tree::verify_merkle_proof_to_cap; + use plonky2::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; + + pub(crate) fn random_data(n: usize, k: usize) -> Vec> { + (0..n).map(|_| F::rand_vec(k)).collect() + } + + fn verify_all_leaves< + F: RichField + Extendable, + C: GenericConfig, + const D: usize, + >( + leaves: Vec>, + cap_height: usize, + ) -> Result<()> { + let tree = MerkleTree::::new(leaves.clone(), cap_height); + for (i, leaf) in leaves.into_iter().enumerate() { + let proof = tree.prove(i); + verify_merkle_proof_to_cap(leaf, i, &tree.cap, &proof)?; + } + Ok(()) + } + + #[test] + #[should_panic] + fn test_cap_height_too_big() { + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + + let log_n = 8; + let cap_height = log_n + 1; // Should panic if `cap_height > len_n`. + + let leaves = random_data::(1 << log_n, 7); + let _ = MerkleTree::>::Hasher>::new(leaves, cap_height); + } + + #[test] + fn test_cap_height_eq_log2_len() -> Result<()> { + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + + let log_n = 8; + let n = 1 << log_n; + let leaves = random_data::(n, 7); + + verify_all_leaves::(leaves, log_n)?; + + Ok(()) + } + + #[test] + fn test_merkle_trees() -> Result<()> { + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + + let log_n = 8; + let n = 1 << log_n; + let leaves = random_data::(n, 7); + + verify_all_leaves::(leaves, 1)?; + + Ok(()) + } +} diff --git a/plonky2_poseidon2/.gitignore b/plonky2_poseidon2/.gitignore new file mode 100644 index 0000000..a3547b9 --- /dev/null +++ b/plonky2_poseidon2/.gitignore @@ -0,0 +1,13 @@ +#IDE Related +.idea + +# Cargo build +/target +Cargo.lock + +# Profile-guided optimization +/tmp +pgo-data.profdata + +# MacOS nuisances +.DS_Store diff --git a/plonky2_poseidon2/BENCHMARKS.md 
b/plonky2_poseidon2/BENCHMARKS.md new file mode 100644 index 0000000..5132c04 --- /dev/null +++ b/plonky2_poseidon2/BENCHMARKS.md @@ -0,0 +1,82 @@ +## Benchmark Results + +Benchmarks comparing the performance of Poseidon and Poseidon2 hash functions within the Plonky2 proving system. The benchmarks measure the time taken to build the circuit, generate the proof, and verify the proof for different numbers of permutations (from 210 to 213 permutations). + +## Running Benchmarks + +To run the benchmarks provided in this crate, you can use the following command: + +```bash +cargo bench --bench poseidon2_perm +``` + +The following operations were benchmarked: + +- **Build Circuit**: Time taken to construct the circuit for the specified number of permutations. +- **Prove Circuit**: Time taken to generate a proof for the constructed circuit. +- **Verify Circuit**: Time taken to verify the generated proof. + +#### Build Time + +| Number of Permutations | Poseidon Build Time (ms) | Poseidon2 Build Time (ms) | +|------------------------|------------------|------------------| +| 210 (1024) | 52.5 | 59.2 | +| 211 (2048) | 114.5 | 120.5 | +| 212 (4096) | 250.4 | 253.6 | +| 213 (8192) | 524.3 | 525.2 | + +#### Prove Time + +| Number of Permutations | Poseidon Prove Time (ms) | Poseidon2 Prove Time (ms) | +|------------------------|------------------|-------------------| +| 210 (1024) | 90.5 | 96.4 | +| 211 (2048) | 184.3 | 193.9 | +| 212 (4096) | 334.6 | 355.9 | +| 213 (8192) | 733.4 | 713.0 | + +#### Verify Time + +| Number of Permutations | Poseidon Verify Time (ms) | Poseidon2 Verify Time (ms) | +|------------------------|-------------------|--------------------| +| 210 (1024) | 2.7 | 2.8 | +| 211 (2048) | 2.9 | 3.0 | +| 212 (4096) | 3.0 | 3.2 | +| 213 (8192) | 3.4 | 3.7 | + +#### Circuit Size + +| Number of Permutations | Circuit Size (Gates) | +|------------------------|------------------------------| +| 210 (1024) | 211 (2048) gates | +| 211 (2048) | 212 (4096) gates | +| 
2^12 (4096) | 2^13 (8192) gates | +| 2^13 (8192) | 2^14 (16384) gates | + +#### Proof Size + +| Number of Permutations | Proof Size (bytes) | +|------------------------|--------------------| +| 2^10 (1024) | 121,608 | +| 2^11 (2048) | 127,112 | +| 2^12 (4096) | 132,744 | +| 2^13 (8192) | 146,276 |
diff --git a/plonky2_poseidon2/Cargo.toml b/plonky2_poseidon2/Cargo.toml new file mode 100644 index 0000000..2472579 --- /dev/null +++ b/plonky2_poseidon2/Cargo.toml @@ -0,0 +1,23 @@ +[package] +name = "plonky2_poseidon2" +description = "Plonky2 with Poseidon2 hash" +authors = ["Mohammed Alghazwi "] +readme = "README.md" +version = "0.1.0" +edition = "2021" + +[dependencies] +anyhow = { version = "1.0.89" } +unroll = { version = "0.1.5", default-features = false } +serde = { version = "1.0.210" , features = ["rc"] } +serde_json = { version = "1.0" } +plonky2 = { version = "0.2.2" } +plonky2_field = { version = "0.2.2", default-features = false } + +[dev-dependencies] +criterion = { version = "0.5.1", default-features = false } +tynm = { version = "0.1.6", default-features = false } + +[[bench]] +name = "poseidon2_perm" +harness = false diff --git a/plonky2_poseidon2/README.md b/plonky2_poseidon2/README.md new file mode 100644 index 0000000..b106cad --- /dev/null +++ b/plonky2_poseidon2/README.md @@ -0,0 +1,73 @@ +# Poseidon2 Plonky2 +WARNING: This is a work-in-progress prototype, and has not received careful code review. This implementation is NOT ready for production use. + +This crate is an implementation of the Poseidon2 Hash that can be employed in the [Plonky2 proving system](https://github.com/0xPolygonZero/plonky2). Poseidon2 hash function is a new zk-friendly hash function, and provides good performance. +The hash and gate implementations are based on the plonky2 Poseidon [hash](https://github.com/0xPolygonZero/plonky2/blob/main/plonky2/src/hash/poseidon.rs) and [gate](https://github.com/0xPolygonZero/plonky2/blob/main/plonky2/src/gates/poseidon.rs). 
+ +The Poseidon2 Hash implementation is consistent with the one here: https://github.com/HorizenLabs/poseidon2 + +## Code Organization +This crate includes: + +- [**Poseidon2 Gate**](./src/gate/poseidon2.rs) +- [**Poseidon2 Hash**](./src/poseidon2_hash/poseidon2.rs) +- [**Poseidon2 Config**](./src/config/mod.rs) +- [**Benchmarks**](./benches/poseidon2_perm.rs) + +This crate can be used to: + +- Generate Plonky2 proofs employing the Poseidon2 hash function +- Write Plonky2 circuits computing Poseidon2 hashes + +## Building + +This crate requires the Rust nightly compiler due to the use of certain unstable features. To install the nightly toolchain, use `rustup`: + +```bash +rustup install nightly +``` + +To ensure that the nightly toolchain is used when building this crate, you can set the override in the project directory: + +```bash +rustup override set nightly +``` + +Alternatively, you can specify the nightly toolchain when building: + +```bash +cargo +nightly build +``` + +## Usage + +The Poseidon2 hash can be used directly to compute hash values over an array of field elements. 
Below is a simplified example demonstrating how to use the Poseidon2 hash function: + +```rust +use crate::poseidon2_hash::poseidon2::{Poseidon2, SPONGE_WIDTH}; +use plonky2_field::goldilocks_field::GoldilocksField as F; +use plonky2_field::types::Field; + +fn main() { + // Create an input array of field elements for hashing + let mut input = [F::ZERO; SPONGE_WIDTH]; + // [0,1,2,3,4,5,6,7,8,9,10,11] + for i in 0..SPONGE_WIDTH { + input[i] = F::from_canonical_u64(i as u64); + } + // Compute the Poseidon2 hash + let output = F::poseidon2(input); + // Print the input values + for i in 0..SPONGE_WIDTH { + println!("input {} = {}", i, input[i]); + } + // Print the output values + for i in 0..SPONGE_WIDTH { + println!("out {} = {}", i, output[i]); + } +} +``` + +## Benchmark Results + +Benchmark results are shown in [BENCHMARKS.md](./BENCHMARKS.md) \ No newline at end of file diff --git a/plonky2_poseidon2/benches/poseidon2_perm.rs b/plonky2_poseidon2/benches/poseidon2_perm.rs new file mode 100644 index 0000000..0c90383 --- /dev/null +++ b/plonky2_poseidon2/benches/poseidon2_perm.rs @@ -0,0 +1,184 @@ +use std::fs; +use anyhow::Result; +use std::time::Instant; +use criterion::{criterion_group, criterion_main, BatchSize, Criterion}; +use plonky2::field::extension::Extendable; +use plonky2::field::goldilocks_field::GoldilocksField; +use plonky2::hash::hash_types::RichField; +use plonky2::hash::poseidon::PoseidonHash; +use plonky2::plonk::circuit_data::{CircuitConfig, CircuitData}; +use plonky2::plonk::config::{AlgebraicHasher, GenericConfig, Hasher, PoseidonGoldilocksConfig}; +use plonky2_poseidon2::config::Poseidon2GoldilocksConfig; +use tynm::type_name; +use plonky2::hash::hashing::PlonkyPermutation; +use plonky2_poseidon2::poseidon2_hash::poseidon2::{Poseidon2, Poseidon2Hash}; +use plonky2::iop::target::Target; +use plonky2::iop::witness::{PartialWitness, WitnessWrite}; +use plonky2::plonk::circuit_builder::CircuitBuilder; +use std::marker::PhantomData; +use 
plonky2::plonk::proof::ProofWithPublicInputs; + +macro_rules! pretty_print { + ($($arg:tt)*) => { + print!("\x1b[0;36mINFO ===========>\x1b[0m "); + println!($($arg)*); + } +} + +pub struct PoseidonCircuit< + F: RichField + Extendable + Poseidon2, + C: GenericConfig, + const D: usize, + H: Hasher + AlgebraicHasher, +> { + public_input: Vec, + circuit_data: CircuitData, + num_powers: usize, + _hasher: PhantomData, +} + +impl< + F: RichField + Extendable + Poseidon2, + C: GenericConfig, + const D: usize, + H: Hasher + AlgebraicHasher, +> PoseidonCircuit +{ + pub fn build_circuit(config: CircuitConfig, log_num_hashes: usize) -> Self { + let num_hashes: usize = 1usize << log_num_hashes; + const T: usize = 12; + + let mut builder = CircuitBuilder::::new(config); + let zero = builder.zero(); + let mut state = H::AlgebraicPermutation::new(core::iter::repeat(zero)); + + let mut initial = Vec::new(); // vec![]; + for _ in 0..T { + let x = builder.add_virtual_public_input(); + initial.push(x); + } + + state.set_from_slice(&initial, 0); + + for k in 0..num_hashes { + state = builder.permute::(state); + } + + let output = state.squeeze(); + for o in output{ + builder.register_public_input(*o); + } + + let data = builder.build::(); + + Self { + public_input: initial, + circuit_data: data, + num_powers: num_hashes, + _hasher: PhantomData::, + } + } + + pub fn generate_proof(&self, init: F) -> Result> { + const T: usize = 12; + + let mut pw = PartialWitness::::new(); + for j in 0..T { + pw.set_target(self.public_input[j], F::from_canonical_usize(j)); + } + + let proof = self.circuit_data.prove(pw).unwrap(); + + Ok(proof) + } + + pub fn get_circuit_data(&self) -> &CircuitData { + &self.circuit_data + } +} + +fn bench_poseidon2_perm< + F: RichField + Extendable + Poseidon2, + const D: usize, + C: GenericConfig, + H: Hasher + AlgebraicHasher, +>( + c: &mut Criterion, + config: CircuitConfig, +) { + + let mut group = c.benchmark_group(&format!( + "poseidon-proof<{}, {}>", + 
type_name::(), + type_name::() + )); + + for log_num_hashes in [ 10, 11, 12, 13 ] { + group.bench_function( + format!("build circuit for 2^{} permutations", log_num_hashes).as_str(), + |b| { + b.iter_with_large_drop(|| { + PoseidonCircuit::::build_circuit(config.clone(), log_num_hashes); + }) + }, + ); + + let poseidon_circuit = + PoseidonCircuit::::build_circuit(config.clone(), log_num_hashes); + + pretty_print!( + "circuit size: 2^{} gates", + poseidon_circuit.get_circuit_data().common.degree_bits() + ); + + group.bench_function( + format!("prove circuit with 2^{} permutations", log_num_hashes).as_str(), + |b| { + b.iter_batched( + || F::rand(), + |init| poseidon_circuit.generate_proof(init).unwrap(), + BatchSize::PerIteration, + ) + }, + ); + + let proof = poseidon_circuit.generate_proof(F::rand()).unwrap(); + + pretty_print!("proof size: {}", proof.to_bytes().len()); + + group.bench_function( + format!("verify circuit with 2^{} permutations", log_num_hashes).as_str(), + |b| { + b.iter_batched( + || (poseidon_circuit.get_circuit_data(), proof.clone()), + |(data, proof)| data.verify(proof).unwrap(), + BatchSize::PerIteration, + ) + }, + ); + } + + group.finish(); +} + +fn benchmark(c: &mut Criterion) { + const D: usize = 2; + type F = GoldilocksField; + + // bench poseidon hash + bench_poseidon2_perm::( + c, + CircuitConfig::standard_recursion_config(), + ); + + // bench poseidon2 hash + bench_poseidon2_perm::( + c, + CircuitConfig::standard_recursion_config(), + ); +} + +criterion_group!(name = benches; + config = Criterion::default().sample_size(10); + targets = benchmark); +criterion_main!(benches); diff --git a/plonky2_poseidon2/src/config/mod.rs b/plonky2_poseidon2/src/config/mod.rs new file mode 100644 index 0000000..2409808 --- /dev/null +++ b/plonky2_poseidon2/src/config/mod.rs @@ -0,0 +1,15 @@ +use plonky2::plonk::config::GenericConfig; +use plonky2_field::extension::quadratic::QuadraticExtension; +use plonky2_field::goldilocks_field::GoldilocksField; 
+use serde::{Deserialize, Serialize}; +use crate::poseidon2_hash::poseidon2::Poseidon2Hash; + +/// Configuration using Poseidon2 over the Goldilocks field. +#[derive(Debug, Copy, Clone, Default, Eq, PartialEq, Serialize, Deserialize)] +pub struct Poseidon2GoldilocksConfig; +impl GenericConfig<2> for Poseidon2GoldilocksConfig { + type F = GoldilocksField; + type FE = QuadraticExtension; + type Hasher = Poseidon2Hash; + type InnerHasher = Poseidon2Hash; +} \ No newline at end of file diff --git a/plonky2_poseidon2/src/gate/mod.rs b/plonky2_poseidon2/src/gate/mod.rs new file mode 100644 index 0000000..4f84889 --- /dev/null +++ b/plonky2_poseidon2/src/gate/mod.rs @@ -0,0 +1 @@ +pub mod poseidon2; \ No newline at end of file diff --git a/plonky2_poseidon2/src/gate/poseidon2.rs b/plonky2_poseidon2/src/gate/poseidon2.rs new file mode 100644 index 0000000..08ed2bb --- /dev/null +++ b/plonky2_poseidon2/src/gate/poseidon2.rs @@ -0,0 +1,633 @@ +//! Implementation of the Poseidon2 hash function as Plonky2 Gate +//! based on Poseidon Gate: +//! https://github.com/0xPolygonZero/plonky2/blob/main/plonky2/src/gates/poseidon.rs +//! 
+use core::marker::PhantomData; +use std::ops::Mul; +use plonky2_field::extension::Extendable; +use plonky2_field::types::Field; +use plonky2::gates::gate::Gate; +use plonky2::gates::util::StridedConstraintConsumer; +use plonky2::hash::hash_types::RichField; +use crate::poseidon2_hash::poseidon2::{Poseidon2, FULL_ROUND_BEGIN, FULL_ROUND_END, PARTIAL_ROUNDS, SPONGE_WIDTH}; +use plonky2::iop::ext_target::ExtensionTarget; +use plonky2::iop::generator::{GeneratedValues, SimpleGenerator, WitnessGeneratorRef}; +use plonky2::iop::target::Target; +use plonky2::iop::wire::Wire; +use plonky2::iop::witness::{PartitionWitness, Witness, WitnessWrite}; +use plonky2::plonk::circuit_builder::CircuitBuilder; +use plonky2::plonk::circuit_data::CommonCircuitData; +use plonky2::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase}; +use plonky2::util::serialization::{Buffer, IoResult, Read, Write}; + +/// Evaluates a full Poseidon2 permutation with 12 state elements. +/// +/// This also has some extra features to make it suitable for efficiently +/// verifying Merkle proofs. It has a flag which can be used to swap the first +/// four inputs with the next four, for ordering sibling digests. +#[derive(Debug, Default)] +pub struct Poseidon2Gate, const D: usize>(PhantomData); + +impl, const D: usize> Poseidon2Gate { + pub fn new() -> Self { + Self(PhantomData) + } + /// The wire index for the `i`th input to the permutation. + pub fn wire_input(i: usize) -> usize { + i + } + + /// The wire index for the `i`th output to the permutation. + pub fn wire_output(i: usize) -> usize { + SPONGE_WIDTH + i + } + /// If this is set to 1, the first four inputs will be swapped with the next + /// four inputs. This is useful for ordering hashes in Merkle proofs. + /// Otherwise, this should be set to 0. 
+ pub const WIRE_SWAP: usize = 2 * SPONGE_WIDTH; + + const START_DELTA: usize = 2 * SPONGE_WIDTH + 1; + + /// A wire which stores `swap * (input[i + 4] - input[i])`; used to compute + /// the swapped inputs. + fn wire_delta(i: usize) -> usize { + assert!(i < 4); + Self::START_DELTA + i + } + + const START_FULL_ROUND_BEGIN: usize = Self::START_DELTA + 4; + + /// A wire which stores the input of the `i`-th S-box of the `round`-th + /// round of the first set of full rounds. + fn wire_first_full_round(round: usize, i: usize) -> usize { + debug_assert!( + round != 0, + "First round S-box inputs are not stored as wires" + ); + debug_assert!(round < FULL_ROUND_BEGIN); + debug_assert!(i < SPONGE_WIDTH); + Self::START_FULL_ROUND_BEGIN + SPONGE_WIDTH * (round - 1) + i + } + + const START_PARTIAL: usize = + Self::START_FULL_ROUND_BEGIN + SPONGE_WIDTH * (FULL_ROUND_BEGIN - 1); + + /// A wire which stores the input of the S-box of the `round`-th round of + /// the partial rounds. + fn wire_partial_round(round: usize) -> usize { + debug_assert!(round < PARTIAL_ROUNDS); + Self::START_PARTIAL + round + } + + const START_FULL_ROUND_END: usize = Self::START_PARTIAL + PARTIAL_ROUNDS; + + /// A wire which stores the input of the `i`-th S-box of the `round`-th + /// round of the second set of full rounds. + fn wire_second_full_round(round: usize, i: usize) -> usize { + debug_assert!(round < FULL_ROUND_BEGIN); + debug_assert!(i < SPONGE_WIDTH); + Self::START_FULL_ROUND_END + SPONGE_WIDTH * round + i + } + + /// End of wire indices, exclusive. 
+ fn end() -> usize { + Self::START_FULL_ROUND_END + SPONGE_WIDTH * FULL_ROUND_BEGIN + } +} +impl + Poseidon2, const D: usize> Gate for Poseidon2Gate { + fn id(&self) -> String { + format!("{:?}", self, SPONGE_WIDTH) + } + + fn serialize( + &self, + _dst: &mut Vec, + _common_data: &CommonCircuitData, + ) -> IoResult<()> { + Ok(()) + } + fn deserialize(_src: &mut Buffer, _common_data: &CommonCircuitData) -> IoResult { + Ok(Poseidon2Gate::new()) + } + fn eval_unfiltered(&self, vars: EvaluationVars) -> Vec { + let mut constraints = Vec::with_capacity(self.num_constraints()); + + // Assert that `swap` is binary. + let swap = vars.local_wires[Self::WIRE_SWAP]; + constraints.push(swap * (swap - F::Extension::ONE)); + + // Assert that each delta wire is set properly: `delta_i = swap * (rhs - lhs)`. + for i in 0..4 { + let input_lhs = vars.local_wires[Self::wire_input(i)]; + let input_rhs = vars.local_wires[Self::wire_input(i + 4)]; + let delta_i = vars.local_wires[Self::wire_delta(i)]; + constraints.push(swap * (input_rhs - input_lhs) - delta_i); + } + + // Compute the possibly-swapped input layer. 
+ let mut state = [F::Extension::ZERO; SPONGE_WIDTH]; + for i in 0..4 { + let delta_i = vars.local_wires[Self::wire_delta(i)]; + let input_lhs = Self::wire_input(i); + let input_rhs = Self::wire_input(i + 4); + state[i] = vars.local_wires[input_lhs] + delta_i; + state[i + 4] = vars.local_wires[input_rhs] - delta_i; + } + for i in 8..SPONGE_WIDTH { + state[i] = vars.local_wires[Self::wire_input(i)]; + } + + // linear layer + ::matmul_external_field(&mut state); + + // First External layer + for r in 0..FULL_ROUND_BEGIN { + ::constant_layer_field(&mut state, r); + if r != 0 { + for i in 0..SPONGE_WIDTH { + let sbox_in = + vars.local_wires[Self::wire_first_full_round(r, i)]; + constraints.push(state[i] - sbox_in); + state[i] = sbox_in; + } + } + ::sbox_layer_field(&mut state); + ::matmul_external_field(&mut state); + } + + // Internal layer + for r in 0..PARTIAL_ROUNDS { + state[0] += F::Extension::from_canonical_u64(::RC12_MID[r]); + + let sbox_in = + vars.local_wires[Self::wire_partial_round(r)]; + constraints.push(state[0] - sbox_in); + state[0] = ::sbox_p(sbox_in); + ::matmul_internal_field(&mut state, &::MAT_DIAG12_M_1); + } + + // Second External layer + for r in FULL_ROUND_BEGIN..FULL_ROUND_END { + ::constant_layer_field(&mut state, r); + + for i in 0..SPONGE_WIDTH { + let sbox_in = + vars.local_wires[Self::wire_second_full_round(r - FULL_ROUND_BEGIN, i)]; + constraints.push(state[i] - sbox_in); + state[i] = sbox_in; + } + + ::sbox_layer_field(&mut state); + ::matmul_external_field(&mut state); + } + + //12 constraints + for i in 0..SPONGE_WIDTH { + constraints.push(state[i] - vars.local_wires[Self::wire_output(i)]); + } + + constraints + } + fn eval_unfiltered_base_one( + &self, + vars: EvaluationVarsBase, + mut yield_constr: StridedConstraintConsumer, + ) { + // Assert that `swap` is binary. 
+ let swap = vars.local_wires[Self::WIRE_SWAP]; + yield_constr.one(swap * swap.sub_one()); + + // Assert that each delta wire is set properly: `delta_i = swap * (rhs - lhs)`. + for i in 0..4 { + let input_lhs = vars.local_wires[Self::wire_input(i)]; + let input_rhs = vars.local_wires[Self::wire_input(i + 4)]; + let delta_i = vars.local_wires[Self::wire_delta(i)]; + yield_constr.one(swap * (input_rhs - input_lhs) - delta_i); + } + + // Compute the possibly-swapped input layer. + let mut state = [F::ZERO; SPONGE_WIDTH]; + for i in 0..4 { + let delta_i = vars.local_wires[Self::wire_delta(i)]; + let input_lhs = Self::wire_input(i); + let input_rhs = Self::wire_input(i + 4); + state[i] = vars.local_wires[input_lhs] + delta_i; + state[i + 4] = vars.local_wires[input_rhs] - delta_i; + } + for i in 8..SPONGE_WIDTH { + state[i] = vars.local_wires[Self::wire_input(i)]; + } + + // linear layer + ::matmul_external(&mut state); + + // First External layer + for r in 0..FULL_ROUND_BEGIN { + ::constant_layer(&mut state, r); + if r != 0 { + for i in 0..SPONGE_WIDTH { + let sbox_in = vars.local_wires[Self::wire_first_full_round(r, i)]; + yield_constr.one(state[i] - sbox_in); + state[i] = sbox_in; + } + } + ::sbox_layer(&mut state); + ::matmul_external(&mut state); + } + + // Internal layer + for r in 0..PARTIAL_ROUNDS { + state[0] += F::from_canonical_u64(::RC12_MID[r]); + let sbox_in = vars.local_wires[Self::wire_partial_round(r)]; + yield_constr.one(state[0] - sbox_in); + state[0] = sbox_in; + state[0] = ::sbox_p(state[0]); + ::matmul_internal(&mut state, &::MAT_DIAG12_M_1); + } + + // Second External layer + for r in FULL_ROUND_BEGIN..FULL_ROUND_END { + ::constant_layer(&mut state, r); + + for i in 0..SPONGE_WIDTH { + let sbox_in = vars.local_wires[Self::wire_second_full_round(r - FULL_ROUND_BEGIN, i)]; + yield_constr.one(state[i] - sbox_in); + state[i] = sbox_in; + } + + ::sbox_layer(&mut state); + ::matmul_external(&mut state); + } + + for i in 0..SPONGE_WIDTH { + 
yield_constr.one(state[i] - vars.local_wires[Self::wire_output(i)]); + } + } + fn eval_unfiltered_circuit( + &self, + builder: &mut CircuitBuilder, + vars: EvaluationTargets, + ) -> Vec> { + let mut constraints = Vec::with_capacity(self.num_constraints()); + + // Assert that `swap` is binary. + let swap = vars.local_wires[Self::WIRE_SWAP]; + constraints.push(builder.mul_sub_extension(swap, swap, swap)); + + // Assert that each delta wire is set properly: `delta_i = swap * (rhs - lhs)`. + for i in 0..4 { + let input_lhs = vars.local_wires[Self::wire_input(i)]; + let input_rhs = vars.local_wires[Self::wire_input(i + 4)]; + let delta_i = vars.local_wires[Self::wire_delta(i)]; + let diff = builder.sub_extension(input_rhs, input_lhs); + constraints.push(builder.mul_sub_extension(swap, diff, delta_i)); + } + + // Compute the possibly-swapped input layer. + let mut state = [builder.zero_extension(); SPONGE_WIDTH]; + for i in 0..4 { + let delta_i = vars.local_wires[Self::wire_delta(i)]; + let input_lhs = vars.local_wires[Self::wire_input(i)]; + let input_rhs = vars.local_wires[Self::wire_input(i + 4)]; + state[i] = builder.add_extension(input_lhs, delta_i); + state[i + 4] = builder.sub_extension(input_rhs, delta_i); + } + for i in 8..SPONGE_WIDTH { + state[i] = vars.local_wires[Self::wire_input(i)]; + } + + // linear layer + state = ::matmul_external_circuit(builder, &mut state); + + // First External layer + for r in 0..FULL_ROUND_BEGIN { + ::constant_layer_circuit(builder, &mut state, r); + if r != 0 { + for i in 0..SPONGE_WIDTH { + let sbox_in = vars.local_wires[Self::wire_first_full_round(r, i)]; + constraints.push(builder.sub_extension(state[i], sbox_in)); + state[i] = sbox_in; + } + } + ::sbox_layer_circuit(builder, &mut state); + state = ::matmul_external_circuit(builder, &mut state); + } + + // Internal layer + for r in 0..PARTIAL_ROUNDS { + let round_constant = F::Extension::from_canonical_u64(::RC12_MID[r]); + let round_constant = 
builder.constant_extension(round_constant); + state[0] = builder.add_extension(state[0], round_constant); + + let sbox_in = vars.local_wires[Self::wire_partial_round(r)]; + constraints.push(builder.sub_extension(state[0], sbox_in)); + //state[0] = sbox_in; + state[0] = ::sbox_p_circuit(builder, sbox_in); + ::matmul_internal_circuit(builder, &mut state); + } + + // Second External layer + for r in FULL_ROUND_BEGIN..FULL_ROUND_END { + ::constant_layer_circuit(builder, &mut state, r); + + for i in 0..SPONGE_WIDTH { + let sbox_in = vars.local_wires[Self::wire_second_full_round(r - FULL_ROUND_BEGIN, i)]; + constraints.push(builder.sub_extension(state[i], sbox_in)); + state[i] = sbox_in; + } + + ::sbox_layer_circuit(builder, &mut state); + state = ::matmul_external_circuit(builder, &mut state); + } + + for i in 0..SPONGE_WIDTH { + constraints + .push(builder.sub_extension(state[i], vars.local_wires[Self::wire_output(i)])); + } + + constraints + } + fn generators(&self, row: usize, _local_constants: &[F]) -> Vec> { + let gen = Poseidon2Generator:: { + row, + _phantom: PhantomData, + }; + vec![WitnessGeneratorRef::new(gen.adapter())] + } + + fn num_wires(&self) -> usize { + Self::end() + } + + fn num_constants(&self) -> usize { + 0 + } + + fn degree(&self) -> usize { + 7 + } + + fn num_constraints(&self) -> usize { + SPONGE_WIDTH * (FULL_ROUND_END - 1) + PARTIAL_ROUNDS + SPONGE_WIDTH + 1 + 4 + } + +} + +#[derive(Debug, Default)] +pub struct Poseidon2Generator + Poseidon2, const D: usize> { + row: usize, + _phantom: PhantomData, +} + +impl + Poseidon2, const D: usize> SimpleGenerator + for Poseidon2Generator +{ + fn id(&self) -> String { + "Poseidon2Generator".to_string() + } + fn dependencies(&self) -> Vec { + (0..SPONGE_WIDTH) + .map(|i| Poseidon2Gate::::wire_input(i)) + .chain(Some(Poseidon2Gate::::WIRE_SWAP)) + .map(|column| Target::wire(self.row, column)) + .collect() + } + fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { + let 
local_wire = |column| Wire { + row: self.row, + column, + }; + + let mut state = (0..SPONGE_WIDTH) + .map(|i| witness.get_wire(local_wire(Poseidon2Gate::::wire_input(i)))) + .collect::>(); + + let swap_value = witness.get_wire(local_wire(Poseidon2Gate::::WIRE_SWAP)); + debug_assert!(swap_value == F::ZERO || swap_value == F::ONE); + + for i in 0..4 { + let delta_i = swap_value * (state[i + 4] - state[i]); + out_buffer.set_wire(local_wire(Poseidon2Gate::::wire_delta(i)), delta_i); + } + + if swap_value == F::ONE { + for i in 0..4 { + state.swap(i, 4 + i); + } + } + + let mut state: [F; SPONGE_WIDTH] = state.try_into().unwrap(); + + // Linear layer + ::matmul_external_field(&mut state); + + // first External layer + for r in 0..FULL_ROUND_BEGIN { + ::constant_layer_field(&mut state, r); + if r != 0 { + for i in 0..SPONGE_WIDTH { + out_buffer.set_wire( + local_wire(Poseidon2Gate::::wire_first_full_round(r, i)), + state[i], + ); + } + } + ::sbox_layer_field(&mut state); + ::matmul_external_field(&mut state); + } + + // Internal layer + for r in 0..PARTIAL_ROUNDS { + state[0] += F::from_canonical_u64(::RC12_MID[r]); + out_buffer.set_wire( + local_wire(Poseidon2Gate::::wire_partial_round(r)), + state[0], + ); + state[0] = ::sbox_p(state[0]); + ::matmul_internal_field(&mut state, &::MAT_DIAG12_M_1); + } + + // Second External layer + for r in FULL_ROUND_BEGIN..FULL_ROUND_END { + ::constant_layer_field(&mut state, r); + + for i in 0..SPONGE_WIDTH { + out_buffer.set_wire( + local_wire(Poseidon2Gate::::wire_second_full_round( + r - FULL_ROUND_BEGIN, + i, + )), + state[i], + ); + } + + ::sbox_layer_field(&mut state); + ::matmul_external_field(&mut state); + } + + for i in 0..SPONGE_WIDTH { + out_buffer.set_wire(local_wire(Poseidon2Gate::::wire_output(i)), state[i]); + } + } + + fn serialize(&self, dst: &mut Vec, _common_data: &CommonCircuitData) -> IoResult<()> { + dst.write_usize(self.row) + } + + fn deserialize(src: &mut Buffer, _common_data: &CommonCircuitData) -> IoResult 
{ + let row = src.read_usize()?; + Ok(Self { + row, + _phantom: PhantomData, + }) + } +} + +//------------------------------------- Tests ----------------------------------------- + +#[cfg(test)] +mod tests { + + use anyhow::Result; + use plonky2_field::goldilocks_field::GoldilocksField; + use plonky2_field::types::Field; + use plonky2::gates::gate_testing::{test_eval_fns, test_low_degree}; + use crate::gate::poseidon2::Poseidon2Gate; + use crate::poseidon2_hash::poseidon2::{Poseidon2, SPONGE_WIDTH}; + use plonky2::iop::generator::generate_partial_witness; + use plonky2::iop::wire::Wire; + use plonky2::iop::witness::{PartialWitness, Witness, WitnessWrite}; + use plonky2::plonk::circuit_builder::CircuitBuilder; + use plonky2::plonk::circuit_data::CircuitConfig; + use plonky2::plonk::config::GenericConfig; + use crate::config::Poseidon2GoldilocksConfig; + + #[test] + fn wire_indices() { + type F = GoldilocksField; + type Gate = Poseidon2Gate; + + assert_eq!(Gate::wire_input(0), 0); + assert_eq!(Gate::wire_input(11), 11); + assert_eq!(Gate::wire_output(0), 12); + assert_eq!(Gate::wire_output(11), 23); + assert_eq!(Gate::WIRE_SWAP, 24); + assert_eq!(Gate::wire_delta(0), 25); + assert_eq!(Gate::wire_delta(3), 28); + } + + #[test] + fn generated_output() { + const D: usize = 2; + type C = Poseidon2GoldilocksConfig; + type F = >::F; + + let config = CircuitConfig { + num_wires: 143, + ..CircuitConfig::standard_recursion_config() + }; + let mut builder = CircuitBuilder::new(config); + type Gate = Poseidon2Gate; + let gate = Gate::new(); + let row = builder.add_gate(gate, vec![]); + let circuit = builder.build_prover::(); + + println!("width = {}", SPONGE_WIDTH); + + let permutation_inputs = (0..SPONGE_WIDTH).map(F::from_canonical_usize).collect::>(); + + for i in 0..SPONGE_WIDTH { + println!("out {} = {}", i, permutation_inputs[i].clone()); + } + + let mut inputs = PartialWitness::new(); + inputs.set_wire( + Wire { + row, + column: Gate::WIRE_SWAP, + }, + F::ZERO, + ); + 
for i in 0..SPONGE_WIDTH { + inputs.set_wire( + Wire { + row, + column: Gate::wire_input(i), + }, + permutation_inputs[i], + ); + } + + let witness = generate_partial_witness(inputs, &circuit.prover_only, &circuit.common); + + let expected_outputs: [F; SPONGE_WIDTH] = F::poseidon2(permutation_inputs.try_into().unwrap()); + for i in 0..SPONGE_WIDTH { + let out = witness.get_wire(Wire { + row: 0, + column: Gate::wire_output(i), + }); + println!("out {} = {}", i, out.clone()); + assert_eq!(out, expected_outputs[i]); + } + } + + #[test] + fn low_degree() { + type F = GoldilocksField; + let gate = Poseidon2Gate::::new(); + test_low_degree(gate) + } + + #[test] + fn eval_fns() -> Result<()> { + const D: usize = 2; + type C = Poseidon2GoldilocksConfig; + type F = >::F; + let gate = Poseidon2Gate::::new(); + test_eval_fns::(gate) + } + + #[test] + fn test_proof() { + use plonky2_field::types::Sample; + use plonky2::gates::gate::Gate; + use plonky2::hash::hash_types::HashOut; + use plonky2::plonk::vars::{EvaluationTargets, EvaluationVars}; + const D: usize = 2; + type C = Poseidon2GoldilocksConfig; + type F = >::F; + let gate = Poseidon2Gate::::new(); + let wires = <>::F as plonky2_field::extension::Extendable>::Extension::rand_vec(gate.num_wires()); + let constants = <>::F as plonky2_field::extension::Extendable>::Extension::rand_vec(gate.num_constants()); + let public_inputs_hash = HashOut::rand(); + + let config = CircuitConfig::standard_recursion_config(); + let mut pw = PartialWitness::new(); + let mut builder = CircuitBuilder::::new(config); + + let wires_t = builder.add_virtual_extension_targets(wires.len()); + let constants_t = builder.add_virtual_extension_targets(constants.len()); + pw.set_extension_targets(&wires_t, &wires); + pw.set_extension_targets(&constants_t, &constants); + let public_inputs_hash_t = builder.add_virtual_hash(); + pw.set_hash_target(public_inputs_hash_t, public_inputs_hash); + + let vars = EvaluationVars { + local_constants: &constants, + 
local_wires: &wires, + public_inputs_hash: &public_inputs_hash, + }; + let evals = gate.eval_unfiltered(vars); + + let vars_t = EvaluationTargets { + local_constants: &constants_t, + local_wires: &wires_t, + public_inputs_hash: &public_inputs_hash_t, + }; + let evals_t = gate.eval_unfiltered_circuit(&mut builder, vars_t); + pw.set_extension_targets(&evals_t, &evals); + let data = builder.build::(); + let proof = data.prove(pw); + assert!(proof.is_ok()); + } +} \ No newline at end of file diff --git a/plonky2_poseidon2/src/lib.rs b/plonky2_poseidon2/src/lib.rs new file mode 100644 index 0000000..eae341f --- /dev/null +++ b/plonky2_poseidon2/src/lib.rs @@ -0,0 +1,3 @@ +pub mod gate; +pub mod poseidon2_hash; +pub mod config; \ No newline at end of file diff --git a/plonky2_poseidon2/src/poseidon2_hash/mod.rs b/plonky2_poseidon2/src/poseidon2_hash/mod.rs new file mode 100644 index 0000000..5adf6b9 --- /dev/null +++ b/plonky2_poseidon2/src/poseidon2_hash/mod.rs @@ -0,0 +1,6 @@ +pub mod poseidon2; +pub mod poseidon2_goldilocks; + +use plonky2::field::types::{Field, PrimeField64, Sample}; +use plonky2::hash::poseidon::Poseidon; +use crate::poseidon2_hash::poseidon2::Poseidon2; \ No newline at end of file diff --git a/plonky2_poseidon2/src/poseidon2_hash/poseidon2.rs b/plonky2_poseidon2/src/poseidon2_hash/poseidon2.rs new file mode 100644 index 0000000..8a423a4 --- /dev/null +++ b/plonky2_poseidon2/src/poseidon2_hash/poseidon2.rs @@ -0,0 +1,549 @@ +//! Implementation of the Poseidon2 hash function, as described in +//! https://eprint.iacr.org/2023/323.pdf +//! The implementation is based on Poseidon hash in Plonky2: +//! 
https://github.com/0xPolygonZero/plonky2/blob/main/plonky2/src/hash/poseidon.rs + +use core::fmt::Debug; +use plonky2_field::extension::{Extendable, FieldExtension}; +use plonky2_field::types::{Field, PrimeField64}; +use unroll::unroll_for_loops; +use crate::gate::poseidon2::Poseidon2Gate; +use plonky2::hash::hash_types::{HashOut, RichField}; +use plonky2::hash::hashing::{compress, hash_n_to_hash_no_pad, PlonkyPermutation}; +use plonky2::iop::ext_target::ExtensionTarget; +use plonky2::iop::target::{BoolTarget, Target}; +use plonky2::plonk::circuit_builder::CircuitBuilder; +use plonky2::plonk::config::{AlgebraicHasher, Hasher}; + +// Constants defining the number of rounds and state width. +// Note: only state width 12 is currently supported. +pub const SPONGE_WIDTH: usize = 12; // state width +pub const DEGREE: usize = 7; // sbox degree +pub const FULL_ROUND_BEGIN: usize = 4; +pub const FULL_ROUND_END: usize = 2 * FULL_ROUND_BEGIN; +pub const PARTIAL_ROUNDS: usize = 22; +pub const ROUNDS: usize = FULL_ROUND_END + PARTIAL_ROUNDS; + + +pub trait Poseidon2: PrimeField64 { + const MAT_DIAG12_M_1: [u64; SPONGE_WIDTH]; + const RC12: [u64; SPONGE_WIDTH * FULL_ROUND_END]; + const RC12_MID: [u64; PARTIAL_ROUNDS]; + + // ------------- Poseidon2 Hash ------------ + #[inline] + fn poseidon2(input: [Self; SPONGE_WIDTH]) -> [Self; SPONGE_WIDTH] { + // state + let mut current_state = input; + + // Linear layer at beginning + Self::matmul_external(&mut current_state); + + // External Rounds 0 -> 4 + for round_ctr in 0..FULL_ROUND_BEGIN { + Self::external_rounds(&mut current_state , round_ctr); + } + + // Internal Rounds + for round_ctr in 0..PARTIAL_ROUNDS { + Self::internal_rounds(&mut current_state ,round_ctr); + } + + // External Rounds 4 -> 8 + for round_ctr in FULL_ROUND_BEGIN..FULL_ROUND_END { + Self::external_rounds(&mut current_state , round_ctr); + } + + current_state + } + + // ------------- matmul external and internal ------------------- + #[inline] + 
#[unroll_for_loops] + fn matmul_external(state: &mut [Self; SPONGE_WIDTH]){ + // Applying cheap 4x4 MDS matrix to each 4-element part of the state + Self::matmul_m4(state); + + // Applying second cheap matrix for t > 4 + let t4: usize = SPONGE_WIDTH / 4; + let mut stored = [Self::ZERO; 4]; + for l in 0..4 { + stored[l] = state[l]; + for j in 1..t4 { + stored[l] = stored[l].add(state[4 * j + l]); + } + } + for i in 0..state.len() { + state[i] = state[i].add(stored[i % 4]); + } + } + + fn matmul_m4 (state: &mut [Self; SPONGE_WIDTH]){ + let t4 = SPONGE_WIDTH / 4; + + for i in 0..t4 { + let start_index = i * 4; + let mut t_0 = state[start_index]; + + t_0 = t_0.add(state[start_index + 1]); + let mut t_1 = state[start_index + 2]; + + t_1 = t_1.add(state[start_index + 3]); + let mut t_2 = t_1; + + t_2 = t_2.multiply_accumulate(state[start_index + 1], Self::TWO); + + let mut t_3 = t_0; + + t_3 = t_3.multiply_accumulate(state[start_index + 3], Self::TWO); + let mut t_4 = t_3; + + t_4 = t_4.multiply_accumulate(t_1, Self::TWO.double()); + + let mut t_5 = t_2; + + t_5 = t_5.multiply_accumulate(t_0, Self::TWO.double()); + + let t_6 = t_3.add(t_5); + + let t_7 = t_2.add(t_4); + + state[start_index] = t_6; + state[start_index + 1] = t_5; + state[start_index + 2] = t_7; + state[start_index + 3] = t_4; + } + } + + #[inline] + #[unroll_for_loops] + fn matmul_internal(current_state: &mut [Self; SPONGE_WIDTH], mat_internal_diag_m_1: &[u64; SPONGE_WIDTH]){ + let sum: u128 = current_state + .iter() + .map(|&x| x.to_noncanonical_u64() as u128) + .sum(); + + current_state + .iter_mut() + .zip(mat_internal_diag_m_1.iter()) + .for_each(|(state_i, &diag_m1)| { + let state_value = state_i.to_noncanonical_u64() as u128; + let multi = (diag_m1 as u128) * state_value + sum; + *state_i = Self::from_noncanonical_u128(multi); + }); + } + + // ------------- external rounds ------------------- + fn external_rounds(state: &mut [Self; SPONGE_WIDTH], round_ctr: usize) { + Self::constant_layer(state, 
round_ctr); + Self::sbox_layer(state); + Self::matmul_external(state); + } + + // Constant Layer + #[inline] + #[unroll_for_loops] + fn constant_layer(state: &mut [Self; SPONGE_WIDTH], round_ctr: usize) { + let ofs = round_ctr * SPONGE_WIDTH; + for i in 0..SPONGE_WIDTH { + let round_constant = Self::RC12[ofs + i]; + unsafe { + state[i] = state[i].add_canonical_u64(round_constant); + } + } + } + + // sbox layer + #[inline] + #[unroll_for_loops] + fn sbox_layer(state: &mut [Self; SPONGE_WIDTH]) { + for i in 0..SPONGE_WIDTH { + state[i] = Self::sbox_p(state[i]); + } + } + #[inline(always)] + fn sbox_p, const D: usize>(x: F) -> F { + // x |--> x^7 + // only d=7 is supported for now + if DEGREE != 7 { panic!("sbox degree not supported") } + let x2 = x.square(); + let x4 = x2.square(); + let x3 = x * x2; + x3 * x4 + } + + // ------------- internal rounds ------------------- + fn internal_rounds(state: &mut [Self; SPONGE_WIDTH], round_ctr: usize) { + state[0] += Self::from_canonical_u64(Self::RC12_MID[round_ctr]); + state[0] = Self::sbox_p(state[0]); + Self::matmul_internal(state, &Self::MAT_DIAG12_M_1); + } + + // ------------- Same functions as above but for field extensions of `Self`. 
+ #[inline] + fn matmul_external_field, const D: usize>( + state: &mut [F], + ) { + // Applying cheap 4x4 MDS matrix to each 4-element part of the state + Self::matmul_m4_field(state); + + // Applying second cheap matrix for t > 4 + let t4: usize = SPONGE_WIDTH / 4; + let mut stored = [F::ZERO; 4]; + for l in 0..4 { + stored[l] = state[l]; + for j in 1..t4 { + stored[l] += state[4 * j + l]; + } + } + for i in 0..state.len() { + state[i] += stored[i % 4]; + } + } + fn matmul_m4_field, const D: usize>(state: &mut [F]) { + let t4 = SPONGE_WIDTH / 4; + + for i in 0..t4 { + let start_index = i * 4; + let mut t_0 = state[start_index]; + + t_0 = t_0.add(state[start_index + 1]); + let mut t_1 = state[start_index + 2]; + + t_1 = t_1.add(state[start_index + 3]); + let mut t_2 = t_1; + + t_2 = t_2.multiply_accumulate(state[start_index + 1], F::TWO); + + let mut t_3 = t_0; + + t_3 = t_3.multiply_accumulate(state[start_index + 3], F::TWO); + let mut t_4 = t_3; + + t_4 = t_4.multiply_accumulate(t_1, F::TWO.double()); + + let mut t_5 = t_2; + + t_5 = t_5.multiply_accumulate(t_0, F::TWO.double()); + + let t_6 = t_3.add(t_5); + + let t_7 = t_2.add(t_4); + + state[start_index] = t_6; + state[start_index + 1] = t_5; + state[start_index + 2] = t_7; + state[start_index + 3] = t_4; + } + } + #[inline] + fn matmul_internal_field, const D: usize>( + input: &mut [F], + mat_internal_diag_m_1: &[u64], + ) { + let sum: F = input.iter().copied().sum(); + + for (input_i, &diag_m1) in input.iter_mut().zip(mat_internal_diag_m_1.iter()) { + let diag = F::from_canonical_u64(diag_m1); + *input_i = *input_i * diag + sum; + } + } + + fn constant_layer_field, const D: usize>( + state: &mut [F; SPONGE_WIDTH], + round_ctr: usize, + ) { + let ofs = round_ctr * SPONGE_WIDTH; + for i in 0..SPONGE_WIDTH { + let round_constant = Self::RC12[ofs + i]; + state[i] += F::from_canonical_u64(round_constant); + } + } + fn sbox_layer_field, const D: usize>( + state: &mut [F; SPONGE_WIDTH], + ) { + for i in 
0..SPONGE_WIDTH { + state[i] = Self::sbox_p(state[i]); + } + } + + //---------- Same functions for circuit (recursion) ----------- + + fn matmul_m4_circuit( + builder: &mut CircuitBuilder, + state: &mut [ExtensionTarget; SPONGE_WIDTH], + ) where + Self: RichField + Extendable, + { + for i in 0..3 { + let start_index = i * 4; + let t_0 = builder.mul_const_add_extension(Self::ONE, state[start_index], state[start_index + 1]); + let t_1 = + builder.mul_const_add_extension(Self::ONE, state[start_index + 2], state[start_index + 3]); + let t_2 = builder.mul_const_add_extension(Self::TWO, state[start_index + 1], t_1); + let t_3 = builder.mul_const_add_extension(Self::TWO, state[start_index + 3], t_0); + let t_4 = builder.mul_const_add_extension(Self::TWO.double(), t_1, t_3); + let t_5 = builder.mul_const_add_extension(Self::TWO.double(), t_0, t_2); + let t_6 = builder.mul_const_add_extension(Self::ONE, t_3, t_5); + let t_7 = builder.mul_const_add_extension(Self::ONE, t_2, t_4); + + state[start_index] = t_6; + state[start_index + 1] = t_5; + state[start_index + 2] = t_7; + state[start_index + 3] = t_4; + } + } + + fn matmul_external_circuit( + builder: &mut CircuitBuilder, + state: &mut [ExtensionTarget; SPONGE_WIDTH], + ) -> [ExtensionTarget; SPONGE_WIDTH] + where + Self: RichField + Extendable, + { + Self::matmul_m4_circuit(builder, state); + + let t4: usize = SPONGE_WIDTH / 4; + let mut stored = [builder.zero_extension(); 4]; + + for l in 0..4 { + let mut sum = state[l]; + for j in 1..t4 { + let idx = 4 * j + l; + sum = builder.add_extension(sum, state[idx]); + } + stored[l] = sum; + } + + let result = state + .iter() + .enumerate() + .map(|(i, &val)| { + let stored_idx = i % 4; + builder.add_extension(val, stored[stored_idx]) + }) + .collect::>(); + + result.try_into().unwrap_or_else(|v: Vec>| { + panic!("Expected a Vec of length {}", SPONGE_WIDTH) + }) + } + + fn constant_layer_circuit( + builder: &mut CircuitBuilder, + state: &mut [ExtensionTarget; SPONGE_WIDTH], + 
rc_index: usize, + ) where + Self: RichField + Extendable, + { + let ofs = rc_index * SPONGE_WIDTH; + for i in 0..SPONGE_WIDTH { + let round_constant = Self::Extension::from_canonical_u64(Self::RC12[ofs + i]); + let round_constant = builder.constant_extension(round_constant); + state[i] = builder.add_extension(state[i], round_constant); + } + } + + fn sbox_layer_circuit( + builder: &mut CircuitBuilder, + state: &mut [ExtensionTarget; SPONGE_WIDTH], + ) where + Self: RichField + Extendable, + { + for i in 0..SPONGE_WIDTH { + state[i] = builder.exp_u64_extension(state[i], DEGREE as u64); + } + } + + fn sbox_p_circuit( + builder: &mut CircuitBuilder, + state: ExtensionTarget, + ) -> ExtensionTarget + where + Self: RichField + Extendable, + { + builder.exp_u64_extension(state, DEGREE as u64) + } + + fn matmul_internal_circuit( + builder: &mut CircuitBuilder, + state: &mut [ExtensionTarget; SPONGE_WIDTH], + ) where + Self: RichField + Extendable, + { + let sum = builder.add_many_extension(state.clone()); + + for (i, input_i) in state.iter_mut().enumerate() { + let constant = Self::from_canonical_u64(Self::MAT_DIAG12_M_1[i]); + + *input_i = builder.mul_const_add_extension(constant, *input_i, sum); + } + } + +} + +#[derive(Default, Clone, Copy, Debug, PartialEq)] +pub struct Poseidon2Permutation { + state: [T; SPONGE_WIDTH], +} + +impl AsRef<[T]> for Poseidon2Permutation { + fn as_ref(&self) -> &[T] { + &self.state + } +} + +impl Eq for Poseidon2Permutation {} + +trait Permuter: Sized { + fn permute(input: [Self; SPONGE_WIDTH]) -> [Self; SPONGE_WIDTH]; +} + +impl Permuter for F { + fn permute(input: [Self; SPONGE_WIDTH]) -> [Self; SPONGE_WIDTH] { + ::poseidon2(input) + } +} + +impl Permuter for Target { + fn permute(_input: [Self; SPONGE_WIDTH]) -> [Self; SPONGE_WIDTH] { + panic!("Call `permute_swapped()` instead of `permute()`"); + } +} + +impl PlonkyPermutation + for Poseidon2Permutation +{ + const RATE: usize = 8; + const WIDTH: usize = SPONGE_WIDTH; + + fn new>(elts: 
I) -> Self { + let mut perm = Self { + state: [T::default(); SPONGE_WIDTH], + }; + perm.set_from_iter(elts, 0); + perm + } + + fn set_elt(&mut self, elt: T, idx: usize) { + self.state[idx] = elt; + } + + fn set_from_slice(&mut self, elts: &[T], start_idx: usize) { + let begin = start_idx; + let end = start_idx + elts.len(); + self.state[begin..end].copy_from_slice(elts); + } + + fn set_from_iter>(&mut self, elts: I, start_idx: usize) { + for (s, e) in self.state[start_idx..].iter_mut().zip(elts) { + *s = e; + } + } + + fn permute(&mut self) { + self.state = T::permute(self.state); + } + + fn squeeze(&self) -> &[T] { + &self.state[..Self::RATE] + } +} + +/// Poseidon2 hash function. +#[derive(Copy, Clone, Debug, Eq, PartialEq)] +pub struct Poseidon2Hash; +impl Hasher for Poseidon2Hash { + const HASH_SIZE: usize = 4 * 8; + type Hash = HashOut; + type Permutation = Poseidon2Permutation; + + fn hash_no_pad(input: &[F]) -> Self::Hash { + hash_n_to_hash_no_pad::(input) + } + + fn two_to_one(left: Self::Hash, right: Self::Hash) -> Self::Hash { + compress::(left, right) + } +} + +impl AlgebraicHasher for Poseidon2Hash { + type AlgebraicPermutation = Poseidon2Permutation; + + fn permute_swapped( + inputs: Self::AlgebraicPermutation, + swap: BoolTarget, + builder: &mut CircuitBuilder, + ) -> Self::AlgebraicPermutation + where + F: RichField + Extendable, + { + let gate_type = Poseidon2Gate::::new(); + let gate = builder.add_gate(gate_type, vec![]); + + let swap_wire = Poseidon2Gate::::WIRE_SWAP; + let swap_wire = Target::wire(gate, swap_wire); + builder.connect(swap.target, swap_wire); + + // Route input wires. + let inputs = inputs.as_ref(); + for i in 0..SPONGE_WIDTH { + let in_wire = Poseidon2Gate::::wire_input(i); + let in_wire = Target::wire(gate, in_wire); + builder.connect(inputs[i], in_wire); + } + + // Collect output wires. 
+ Self::AlgebraicPermutation::new( + (0..SPONGE_WIDTH).map(|i| Target::wire(gate, Poseidon2Gate::::wire_output(i))), + ) + } +} + +#[cfg(test)] +pub(crate) mod test_helpers { + + use crate::poseidon2_hash::poseidon2::{Poseidon2, SPONGE_WIDTH}; + + pub(crate) fn check_test_vectors(test_vectors: Vec<([u64; SPONGE_WIDTH], [u64; SPONGE_WIDTH])>) + where + F: Poseidon2, + { + for (input_, expected_output_) in test_vectors.into_iter() { + let mut input = [F::ZERO; SPONGE_WIDTH]; + for i in 0..SPONGE_WIDTH { + input[i] = F::from_canonical_u64(input_[i]); + } + let output = F::poseidon2(input); + for i in 0..SPONGE_WIDTH { + let ex_output = F::from_canonical_u64(expected_output_[i]); + assert_eq!(output[i], ex_output); + } + } + } +} + +#[cfg(test)] +pub(crate) mod test_consistency { + use plonky2::hash::hashing::PlonkyPermutation; + use plonky2::plonk::circuit_builder::CircuitBuilder; + use crate::poseidon2_hash::poseidon2::{Poseidon2, Poseidon2Permutation, SPONGE_WIDTH}; + use plonky2_field::goldilocks_field::GoldilocksField as F; + use plonky2_field::types::Field; + + #[test] + pub(crate) fn p2new_check_con() + { + let mut input = [F::ZERO; SPONGE_WIDTH]; + for i in 0..SPONGE_WIDTH { + input[i] = F::from_canonical_u64(i as u64); + } + let output = F::poseidon2(input); + for i in 0..SPONGE_WIDTH { + println!("input {} = {}", i, input[i]); + } + for i in 0..SPONGE_WIDTH { + println!("out {} = {}", i, output[i]); + } + } +} diff --git a/plonky2_poseidon2/src/poseidon2_hash/poseidon2_goldilocks.rs b/plonky2_poseidon2/src/poseidon2_hash/poseidon2_goldilocks.rs new file mode 100644 index 0000000..8203a85 --- /dev/null +++ b/plonky2_poseidon2/src/poseidon2_hash/poseidon2_goldilocks.rs @@ -0,0 +1,176 @@ +//! Implementations for Poseidon over Goldilocks field of width 12. +//! +//! These contents of the implementations is consistent with that in: +//! 
https://github.com/HorizenLabs/poseidon2 + +use plonky2_field::goldilocks_field::GoldilocksField; +use crate::poseidon2_hash::poseidon2::{Poseidon2, FULL_ROUND_END, PARTIAL_ROUNDS, SPONGE_WIDTH}; + +impl Poseidon2 for GoldilocksField { + const MAT_DIAG12_M_1: [u64; SPONGE_WIDTH] = [ + 0xc3b6c08e23ba9300, + 0xd84b5de94a324fb6, + 0x0d0c371c5b35b84f, + 0x7964f570e7188037, + 0x5daf18bbd996604b, + 0x6743bc47b9595257, + 0x5528b9362c59bb70, + 0xac45e25b7127b68b, + 0xa2077d7dfbb606b5, + 0xf3faac6faee378ae, + 0x0c6388b51545e883, + 0xd27dbb6944917b60, + ]; + + const RC12: [u64; SPONGE_WIDTH * FULL_ROUND_END] = [ + 0x13dcf33aba214f46, + 0x30b3b654a1da6d83, + 0x1fc634ada6159b56, + 0x937459964dc03466, + 0xedd2ef2ca7949924, + 0xede9affde0e22f68, + 0x8515b9d6bac9282d, + 0x6b5c07b4e9e900d8, + 0x1ec66368838c8a08, + 0x9042367d80d1fbab, + 0x400283564a3c3799, + 0x4a00be0466bca75e, + 0x7913beee58e3817f, + 0xf545e88532237d90, + 0x22f8cb8736042005, + 0x6f04990e247a2623, + 0xfe22e87ba37c38cd, + 0xd20e32c85ffe2815, + 0x117227674048fe73, + 0x4e9fb7ea98a6b145, + 0xe0866c232b8af08b, + 0x00bbc77916884964, + 0x7031c0fb990d7116, + 0x240a9e87cf35108f, + 0x2e6363a5a12244b3, + 0x5e1c3787d1b5011c, + 0x4132660e2a196e8b, + 0x3a013b648d3d4327, + 0xf79839f49888ea43, + 0xfe85658ebafe1439, + 0xb6889825a14240bd, + 0x578453605541382b, + 0x4508cda8f6b63ce9, + 0x9c3ef35848684c91, + 0x0812bde23c87178c, + 0xfe49638f7f722c14, + 0x8e3f688ce885cbf5, + 0xb8e110acf746a87d, + 0xb4b2e8973a6dabef, + 0x9e714c5da3d462ec, + 0x6438f9033d3d0c15, + 0x24312f7cf1a27199, + 0x23f843bb47acbf71, + 0x9183f11a34be9f01, + 0x839062fbb9d45dbf, + 0x24b56e7e6c2e43fa, + 0xe1683da61c962a72, + 0xa95c63971a19bfa7, + 0xc68be7c94882a24d, + 0xaf996d5d5cdaedd9, + 0x9717f025e7daf6a5, + 0x6436679e6e7216f4, + 0x8a223d99047af267, + 0xbb512e35a133ba9a, + 0xfbbf44097671aa03, + 0xf04058ebf6811e61, + 0x5cca84703fac7ffb, + 0x9b55c7945de6469f, + 0x8e05bf09808e934f, + 0x2ea900de876307d7, + 0x7748fff2b38dfb89, + 0x6b99a676dd3b5d81, + 0xac4bb7c627cf7c13, + 
0xadb6ebe5e9e2f5ba, + 0x2d33378cafa24ae3, + 0x1e5b73807543f8c2, + 0x09208814bfebb10f, + 0x782e64b6bb5b93dd, + 0xadd5a48eac90b50f, + 0xadd4c54c736ea4b1, + 0xd58dbb86ed817fd8, + 0x6d5ed1a533f34ddd, + 0x28686aa3e36b7cb9, + 0x591abd3476689f36, + 0x047d766678f13875, + 0xa2a11112625f5b49, + 0x21fd10a3f8304958, + 0xf9b40711443b0280, + 0xd2697eb8b2bde88e, + 0x3493790b51731b3f, + 0x11caf9dd73764023, + 0x7acfb8f72878164e, + 0x744ec4db23cefc26, + 0x1e00e58f422c6340, + 0x21dd28d906a62dda, + 0xf32a46ab5f465b5f, + 0xbfce13201f3f7e6b, + 0xf30d2e7adb5304e2, + 0xecdf4ee4abad48e9, + 0xf94e82182d395019, + 0x4ee52e3744d887c5, + 0xa1341c7cac0083b2, + 0x2302fb26c30c834a, + 0xaea3c587273bf7d3, + 0xf798e24961823ec7, + 0x962deba3e9a2cd94, + ]; + + const RC12_MID: [u64; PARTIAL_ROUNDS] = [ + 0x4adf842aa75d4316, + 0xf8fbb871aa4ab4eb, + 0x68e85b6eb2dd6aeb, + 0x07a0b06b2d270380, + 0xd94e0228bd282de4, + 0x8bdd91d3250c5278, + 0x209c68b88bba778f, + 0xb5e18cdab77f3877, + 0xb296a3e808da93fa, + 0x8370ecbda11a327e, + 0x3f9075283775dad8, + 0xb78095bb23c6aa84, + 0x3f36b9fe72ad4e5f, + 0x69bc96780b10b553, + 0x3f1d341f2eb7b881, + 0x4e939e9815838818, + 0xda366b3ae2a31604, + 0xbc89db1e7287d509, + 0x6102f411f9ef5659, + 0x58725c5e7ac1f0ab, + 0x0df5856c798883e7, + 0xf7bb62a8da4c961b, + ]; +} + +#[cfg(test)] +mod tests { + + use plonky2_field::goldilocks_field::GoldilocksField as F; + use plonky2_field::types::{Field, PrimeField64}; + use crate::poseidon2_hash::poseidon2::test_helpers::check_test_vectors; + + #[test] + fn p2new_test_vectors() { + // Test inputs are: + // 1. 
range 0..WIDTH + // expected output calculated with reference implementation here: + // https://github.com/HorizenLabs/poseidon2 + + let neg_one: u64 = F::NEG_ONE.to_canonical_u64(); + + #[rustfmt::skip] + let test_vectors12: Vec<([u64; 12], [u64; 12])> = vec![ + ([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, ], + [0x01eaef96bdf1c0c1, 0x1f0d2cc525b2540c, 0x6282c1dfe1e0358d, 0xe780d721f698e1e6, + 0x280c0b6f753d833b, 0x1b942dd5023156ab, 0x43f0df3fcccb8398, 0xe8e8190585489025, + 0x56bdbf72f77ada22, 0x7911c32bf9dcd705, 0xec467926508fbe67, 0x6a50450ddf85a6ed,]), + ]; + + check_test_vectors::(test_vectors12); + } +}