diff --git a/codex-plonky2-circuits/src/circuits/keyed_compress.rs b/codex-plonky2-circuits/src/circuits/keyed_compress.rs new file mode 100644 index 0000000..2cd6c9e --- /dev/null +++ b/codex-plonky2-circuits/src/circuits/keyed_compress.rs @@ -0,0 +1,56 @@ +use plonky2::hash::hash_types::{HashOut, HashOutTarget, RichField, NUM_HASH_OUT_ELTS}; +use plonky2::hash::hashing::PlonkyPermutation; +use plonky2::iop::target::Target; +use plonky2::plonk::circuit_builder::CircuitBuilder; +use plonky2::plonk::config::{AlgebraicHasher, Hasher}; +use plonky2_field::extension::Extendable; +use plonky2_poseidon2::poseidon2_hash::poseidon2::Poseidon2; + +/// Compression function which takes two 256 bit inputs (HashOut) and u64 key (which is converted to field element in the function) +/// and returns a 256 bit output (HashOut). +pub fn key_compress >(x: HashOut, y: HashOut, key: u64) -> HashOut { + + debug_assert_eq!(x.elements.len(), NUM_HASH_OUT_ELTS); + debug_assert_eq!(y.elements.len(), NUM_HASH_OUT_ELTS); + + let key_field = F::from_canonical_u64(key); + + let mut perm = H::Permutation::new(core::iter::repeat(F::ZERO)); + perm.set_from_slice(&x.elements, 0); + perm.set_from_slice(&y.elements, NUM_HASH_OUT_ELTS); + perm.set_elt(key_field,NUM_HASH_OUT_ELTS*2); + + perm.permute(); + + HashOut { + elements: perm.squeeze()[..NUM_HASH_OUT_ELTS].try_into().unwrap(), + } +} + +/// same as above but in-circuit +pub fn key_compress_circuit< + F: RichField + Extendable + Poseidon2, + const D: usize, + H: AlgebraicHasher, +>( + builder: &mut CircuitBuilder, + x: Vec, + y: Vec, + key: Target, +) -> HashOutTarget { + let zero = builder.zero(); + let mut state = H::AlgebraicPermutation::new(core::iter::repeat(zero)); + + state.set_from_slice(&x, 0); + state.set_from_slice(&y, NUM_HASH_OUT_ELTS); + state.set_elt(key, NUM_HASH_OUT_ELTS*2); + + state = builder.permute::(state); + + HashOutTarget { + elements: state.squeeze()[..NUM_HASH_OUT_ELTS].try_into().unwrap(), + } +} + + + diff --git a/codex-plonky2-circuits/src/circuits/mod.rs b/codex-plonky2-circuits/src/circuits/mod.rs index 41bb46e..523eb76 100644 --- a/codex-plonky2-circuits/src/circuits/mod.rs +++ b/codex-plonky2-circuits/src/circuits/mod.rs @@ -3,4 +3,5 @@ pub mod safe_tree_circuit; pub mod prove_single_cell; pub mod sample_cells; pub mod utils; -pub mod params; \ No newline at end of file +pub mod params; +pub mod keyed_compress; \ No newline at end of file diff --git a/codex-plonky2-circuits/src/circuits/prove_single_cell.rs b/codex-plonky2-circuits/src/circuits/prove_single_cell.rs index da9eba9..5fe9b7c 100644 --- a/codex-plonky2-circuits/src/circuits/prove_single_cell.rs +++ b/codex-plonky2-circuits/src/circuits/prove_single_cell.rs @@ -270,9 +270,9 @@ impl< // Create virtual targets let mut leaf = (0..N_FIELD_ELEMS_PER_CELL).map(|_| builder.add_virtual_target()).collect::>(); - let mut perm_inputs:Vec= Vec::new(); - perm_inputs.extend_from_slice(&leaf); - let leaf_hash = builder.hash_n_to_hash_no_pad::(perm_inputs); + let mut hash_inputs:Vec= Vec::new(); + hash_inputs.extend_from_slice(&leaf); + let leaf_hash = builder.hash_n_to_hash_no_pad::(hash_inputs); // path bits (binary decomposition of leaf_index) let mut block_path_bits = (0..BOT_DEPTH).map(|_| builder.add_virtual_bool_target_safe()).collect::>(); diff --git a/codex-plonky2-circuits/src/circuits/safe_tree_circuit.rs b/codex-plonky2-circuits/src/circuits/safe_tree_circuit.rs index 74f2f7c..d74a935 100644 --- a/codex-plonky2-circuits/src/circuits/safe_tree_circuit.rs +++ b/codex-plonky2-circuits/src/circuits/safe_tree_circuit.rs @@ -18,6 +18,7 @@ use plonky2::plonk::proof::{Proof, ProofWithPublicInputs}; use std::marker::PhantomData; use plonky2_poseidon2::poseidon2_hash::poseidon2::Poseidon2; use serde::Serialize; +use crate::circuits::keyed_compress::key_compress_circuit; use crate::circuits::params::HF; use crate::circuits::utils::usize_to_bits_le_padded; @@ -183,12 +184,7 @@ impl< right.push( builder.select(bit, state.elements[i], sibling.elements[i])); } - // hash left, right, and key - let mut perm_inputs:Vec= Vec::new(); - perm_inputs.extend_from_slice(&left); - perm_inputs.extend_from_slice(&right); - perm_inputs.push(key); - state = builder.hash_n_to_hash_no_pad::(perm_inputs); + state = key_compress_circuit::(builder,left,right,key); i += 1; } diff --git a/codex-plonky2-circuits/src/circuits/sample_cells.rs b/codex-plonky2-circuits/src/circuits/sample_cells.rs index fb675bf..f351105 100644 --- a/codex-plonky2-circuits/src/circuits/sample_cells.rs +++ b/codex-plonky2-circuits/src/circuits/sample_cells.rs @@ -291,9 +291,9 @@ impl< // cell data targets let mut data_i = (0..N_FIELD_ELEMS_PER_CELL).map(|_| builder.add_virtual_target()).collect::>(); - let mut perm_inputs:Vec= Vec::new(); - perm_inputs.extend_from_slice(&data_i); - let data_i_hash = builder.hash_n_to_hash_no_pad::(perm_inputs); + let mut hash_inputs:Vec= Vec::new(); + hash_inputs.extend_from_slice(&data_i); + let data_i_hash = builder.hash_n_to_hash_no_pad::(hash_inputs); // counter constant let ctr_target = builder.constant(F::from_canonical_u64((i+1) as u64)); let mut ctr = builder.add_virtual_hash(); diff --git a/codex-plonky2-circuits/src/merkle_tree/merkle_safe.rs b/codex-plonky2-circuits/src/merkle_tree/merkle_safe.rs index 03d29f8..b23fa20 100644 --- a/codex-plonky2-circuits/src/merkle_tree/merkle_safe.rs +++ b/codex-plonky2-circuits/src/merkle_tree/merkle_safe.rs @@ -10,7 +10,8 @@ use plonky2::hash::poseidon::PoseidonHash; use plonky2::plonk::config::Hasher; use std::ops::Shr; use plonky2_field::types::Field; - +use crate::circuits::keyed_compress::key_compress; +use crate::circuits::params::HF; // Constants for the keys used in compression pub const KEY_NONE: u64 = 0x0; @@ -18,11 +19,6 @@ pub const KEY_BOTTOM_LAYER: u64 = 0x1; pub const KEY_ODD: u64 = 0x2; pub const KEY_ODD_AND_BOTTOM_LAYER: u64 = 0x3; -// hash function used. this is hackish way of doing it because -// H::Hash is not consistent with HashOut and causing a lot of headache -// will look into this later. -type HF = PoseidonHash; - /// Merkle tree struct, containing the layers, compression function, and zero hash. #[derive(Clone)] pub struct MerkleTree { @@ -92,16 +88,6 @@ impl MerkleTree { } } -/// compress input (x and y) with key using the define HF hash function -fn key_compress(x: HashOut, y: HashOut, key: u64) -> HashOut { - let key_field = F::from_canonical_u64(key); - let mut inputs = Vec::new(); - inputs.extend_from_slice(&x.elements); - inputs.extend_from_slice(&y.elements); - inputs.push(key_field); - HF::hash_no_pad(&inputs) // TODO: double-check this function -} - /// Build the Merkle tree layers. fn merkle_tree_worker( xs: &[HashOut], @@ -121,7 +107,7 @@ fn merkle_tree_worker( for i in 0..halfn { let key = if is_bottom_layer { KEY_BOTTOM_LAYER } else { KEY_NONE }; - let h = key_compress::(xs[2 * i], xs[2 * i + 1], key); + let h = key_compress::(xs[2 * i], xs[2 * i + 1], key); ys.push(h); } @@ -131,7 +117,7 @@ fn merkle_tree_worker( } else { KEY_ODD }; - let h = key_compress::(xs[n], zero, key); + let h = key_compress::(xs[n], zero, key); ys.push(h); } @@ -169,14 +155,14 @@ impl MerkleProof { let odd_index = (j & 1) != 0; if odd_index { // The index of the child is odd - h = key_compress::(*p, h, bottom_flag); + h = key_compress::(*p, h, bottom_flag); } else { if j == m - 1 { // Single child -> so odd node - h = key_compress::(h, *p, bottom_flag + 2); + h = key_compress::(h, *p, bottom_flag + 2); } else { // Even node - h = key_compress::(h, *p, bottom_flag); + h = key_compress::(h, *p, bottom_flag); } } bottom_flag = KEY_NONE; @@ -207,9 +193,9 @@ impl MerkleProof { let key = bottom + (2 * (odd as u64)); let odd_index = path_bits[i]; if odd_index { - h = key_compress::(*p, h, key); + h = key_compress::(*p, h, key); } else { - h = key_compress::(h, *p, key); + h = key_compress::(h, *p, key); } i += 1; } @@ -245,6 +231,7 @@ fn compute_is_last(path_bits: Vec, last_bits: Vec) -> Vec { mod tests { use super::*; use plonky2::field::types::Field; + use crate::circuits::keyed_compress::key_compress; // types used in all tests type F = GoldilocksField; @@ -255,12 +242,7 @@ mod tests { y: HashOut, key: u64, ) -> HashOut { - let key_field = F::from_canonical_u64(key); - let mut inputs = Vec::new(); - inputs.extend_from_slice(&x.elements); - inputs.extend_from_slice(&y.elements); - inputs.push(key_field); - PoseidonHash::hash_no_pad(&inputs) + key_compress::(x,y,key) } fn make_tree(