add key_compress and refactor

M Alghazwi 2024-11-03 11:48:35 +01:00
parent c725043c4d
commit 4d723f8c7f
6 changed files with 77 additions and 42 deletions

View File

@@ -0,0 +1,56 @@
use plonky2::hash::hash_types::{HashOut, HashOutTarget, RichField, NUM_HASH_OUT_ELTS};
use plonky2::hash::hashing::PlonkyPermutation;
use plonky2::iop::target::Target;
use plonky2::plonk::circuit_builder::CircuitBuilder;
use plonky2::plonk::config::{AlgebraicHasher, Hasher};
use plonky2_field::extension::Extendable;
use plonky2_poseidon2::poseidon2_hash::poseidon2::Poseidon2;
/// Compression function which takes two 256-bit inputs (HashOut) and a u64 key (converted to a field element inside the function),
/// and returns a 256-bit output (HashOut).
pub fn key_compress<F: RichField, H: Hasher<F>>(x: HashOut<F>, y: HashOut<F>, key: u64) -> HashOut<F> {
debug_assert_eq!(x.elements.len(), NUM_HASH_OUT_ELTS);
debug_assert_eq!(y.elements.len(), NUM_HASH_OUT_ELTS);
let key_field = F::from_canonical_u64(key);
let mut perm = H::Permutation::new(core::iter::repeat(F::ZERO));
perm.set_from_slice(&x.elements, 0);
perm.set_from_slice(&y.elements, NUM_HASH_OUT_ELTS);
perm.set_elt(key_field,NUM_HASH_OUT_ELTS*2);
perm.permute();
HashOut {
elements: perm.squeeze()[..NUM_HASH_OUT_ELTS].try_into().unwrap(),
}
}
/// Same as `key_compress` above, but in-circuit.
pub fn key_compress_circuit<
F: RichField + Extendable<D> + Poseidon2,
const D: usize,
H: AlgebraicHasher<F>,
>(
builder: &mut CircuitBuilder<F, D>,
x: Vec<Target>,
y: Vec<Target>,
key: Target,
) -> HashOutTarget {
let zero = builder.zero();
let mut state = H::AlgebraicPermutation::new(core::iter::repeat(zero));
state.set_from_slice(&x, 0);
state.set_from_slice(&y, NUM_HASH_OUT_ELTS);
state.set_elt(key, NUM_HASH_OUT_ELTS*2);
state = builder.permute::<H>(state);
HashOutTarget {
elements: state.squeeze()[..NUM_HASH_OUT_ELTS].try_into().unwrap(),
}
}
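Illustrative usage of the new key_compress (not part of this commit): a minimal sketch that plugs in plonky2's built-in PoseidonHash as the hasher, whereas the codebase would normally supply its HF (Poseidon2) type; the dummy digests and the bottom-layer key are assumptions for the example only.

use plonky2::field::goldilocks_field::GoldilocksField;
use plonky2::field::types::Field;
use plonky2::hash::hash_types::HashOut;
use plonky2::hash::poseidon::PoseidonHash;
use crate::circuits::keyed_compress::key_compress; // module added in this commit

fn demo_key_compress() -> HashOut<GoldilocksField> {
    type F = GoldilocksField;
    // Two child digests (dummy values, for illustration only).
    let x = HashOut::<F> { elements: [F::ONE; 4] };
    let y = HashOut::<F> { elements: [F::TWO; 4] };
    // Compress with the bottom-layer key (KEY_BOTTOM_LAYER = 0x1).
    key_compress::<F, PoseidonHash>(x, y, 0x1)
}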

View File

@@ -4,3 +4,4 @@ pub mod prove_single_cell;
pub mod sample_cells;
pub mod utils;
pub mod params;
pub mod keyed_compress;

View File

@@ -270,9 +270,9 @@ impl<
// Create virtual targets
let mut leaf = (0..N_FIELD_ELEMS_PER_CELL).map(|_| builder.add_virtual_target()).collect::<Vec<_>>();
let mut perm_inputs:Vec<Target>= Vec::new();
perm_inputs.extend_from_slice(&leaf);
let leaf_hash = builder.hash_n_to_hash_no_pad::<HF>(perm_inputs);
let mut hash_inputs:Vec<Target>= Vec::new();
hash_inputs.extend_from_slice(&leaf);
let leaf_hash = builder.hash_n_to_hash_no_pad::<HF>(hash_inputs);
// path bits (binary decomposition of leaf_index)
let mut block_path_bits = (0..BOT_DEPTH).map(|_| builder.add_virtual_bool_target_safe()).collect::<Vec<_>>();

View File

@@ -18,6 +18,7 @@ use plonky2::plonk::proof::{Proof, ProofWithPublicInputs};
use std::marker::PhantomData;
use plonky2_poseidon2::poseidon2_hash::poseidon2::Poseidon2;
use serde::Serialize;
use crate::circuits::keyed_compress::key_compress_circuit;
use crate::circuits::params::HF;
use crate::circuits::utils::usize_to_bits_le_padded;
@@ -183,12 +184,7 @@ impl<
right.push( builder.select(bit, state.elements[i], sibling.elements[i]));
}
// hash left, right, and key
let mut perm_inputs:Vec<Target>= Vec::new();
perm_inputs.extend_from_slice(&left);
perm_inputs.extend_from_slice(&right);
perm_inputs.push(key);
state = builder.hash_n_to_hash_no_pad::<HF>(perm_inputs);
state = key_compress_circuit::<F,D,HF>(builder,left,right,key);
i += 1;
}
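The manual hash-input assembly is replaced by a single call to key_compress_circuit. Below is a minimal, illustrative sketch (not from the commit) of wiring that call into a standalone circuit; it uses plonky2's PoseidonHash in place of the codebase's HF and assumes GoldilocksField satisfies the Poseidon2 trait bound required by the function.

use plonky2::field::goldilocks_field::GoldilocksField;
use plonky2::hash::poseidon::PoseidonHash;
use plonky2::plonk::circuit_builder::CircuitBuilder;
use plonky2::plonk::circuit_data::CircuitConfig;
use crate::circuits::keyed_compress::key_compress_circuit;

fn demo_key_compress_circuit() {
    const D: usize = 2;
    type F = GoldilocksField;
    let config = CircuitConfig::standard_recursion_config();
    let mut builder = CircuitBuilder::<F, D>::new(config);
    // Virtual targets for the two child digests and the compression key.
    let left = builder.add_virtual_hash();
    let right = builder.add_virtual_hash();
    let key = builder.add_virtual_target();
    // Compress in-circuit and expose the parent digest as public inputs.
    let parent = key_compress_circuit::<F, D, PoseidonHash>(
        &mut builder,
        left.elements.to_vec(),
        right.elements.to_vec(),
        key,
    );
    builder.register_public_inputs(&parent.elements);
}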

View File

@@ -291,9 +291,9 @@ impl<
// cell data targets
let mut data_i = (0..N_FIELD_ELEMS_PER_CELL).map(|_| builder.add_virtual_target()).collect::<Vec<_>>();
let mut perm_inputs:Vec<Target>= Vec::new();
perm_inputs.extend_from_slice(&data_i);
let data_i_hash = builder.hash_n_to_hash_no_pad::<HF>(perm_inputs);
let mut hash_inputs:Vec<Target>= Vec::new();
hash_inputs.extend_from_slice(&data_i);
let data_i_hash = builder.hash_n_to_hash_no_pad::<HF>(hash_inputs);
// counter constant
let ctr_target = builder.constant(F::from_canonical_u64((i+1) as u64));
let mut ctr = builder.add_virtual_hash();

View File

@@ -10,7 +10,8 @@ use plonky2::hash::poseidon::PoseidonHash;
use plonky2::plonk::config::Hasher;
use std::ops::Shr;
use plonky2_field::types::Field;
use crate::circuits::keyed_compress::key_compress;
use crate::circuits::params::HF;
// Constants for the keys used in compression
pub const KEY_NONE: u64 = 0x0;
@@ -18,11 +19,6 @@ pub const KEY_BOTTOM_LAYER: u64 = 0x1;
pub const KEY_ODD: u64 = 0x2;
pub const KEY_ODD_AND_BOTTOM_LAYER: u64 = 0x3;
// hash function used. this is a hackish way of doing it because
// H::Hash is not consistent with HashOut<F>, causing a lot of headaches;
// will look into this later.
type HF = PoseidonHash;
/// Merkle tree struct, containing the layers, compression function, and zero hash.
#[derive(Clone)]
pub struct MerkleTree<F: RichField> {
@@ -92,16 +88,6 @@ impl<F: RichField> MerkleTree<F> {
}
}
/// compress input (x and y) with key using the defined HF hash function
fn key_compress<F: RichField>(x: HashOut<F>, y: HashOut<F>, key: u64) -> HashOut<F> {
let key_field = F::from_canonical_u64(key);
let mut inputs = Vec::new();
inputs.extend_from_slice(&x.elements);
inputs.extend_from_slice(&y.elements);
inputs.push(key_field);
HF::hash_no_pad(&inputs) // TODO: double-check this function
}
/// Build the Merkle tree layers.
fn merkle_tree_worker<F: RichField>(
xs: &[HashOut<F>],
@@ -121,7 +107,7 @@ fn merkle_tree_worker<F: RichField>(
for i in 0..halfn {
let key = if is_bottom_layer { KEY_BOTTOM_LAYER } else { KEY_NONE };
let h = key_compress::<F>(xs[2 * i], xs[2 * i + 1], key);
let h = key_compress::<F, HF>(xs[2 * i], xs[2 * i + 1], key);
ys.push(h);
}
@@ -131,7 +117,7 @@ fn merkle_tree_worker<F: RichField>(
} else {
KEY_ODD
};
let h = key_compress::<F>(xs[n], zero, key);
let h = key_compress::<F, HF>(xs[n], zero, key);
ys.push(h);
}
@@ -169,14 +155,14 @@ impl<F: RichField> MerkleProof<F> {
let odd_index = (j & 1) != 0;
if odd_index {
// The index of the child is odd
h = key_compress::<F>(*p, h, bottom_flag);
h = key_compress::<F,HF>(*p, h, bottom_flag);
} else {
if j == m - 1 {
// Single child -> so odd node
h = key_compress::<F>(h, *p, bottom_flag + 2);
h = key_compress::<F,HF>(h, *p, bottom_flag + 2);
} else {
// Even node
h = key_compress::<F>(h, *p, bottom_flag);
h = key_compress::<F,HF>(h, *p, bottom_flag);
}
}
bottom_flag = KEY_NONE;
@@ -207,9 +193,9 @@ impl<F: RichField> MerkleProof<F> {
let key = bottom + (2 * (odd as u64));
let odd_index = path_bits[i];
if odd_index {
h = key_compress::<F>(*p, h, key);
h = key_compress::<F,HF>(*p, h, key);
} else {
h = key_compress::<F>(h, *p, key);
h = key_compress::<F,HF>(h, *p, key);
}
i += 1;
}
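For reference, the key selection used both when building the tree and when reconstructing the root (bottom + 2 * odd, matching the four constants above) can be summarised by a small helper. This is an illustrative sketch, not code from the commit.

fn select_key(is_bottom_layer: bool, is_odd: bool) -> u64 {
    // KEY_NONE = 0, KEY_BOTTOM_LAYER = 1, KEY_ODD = 2, KEY_ODD_AND_BOTTOM_LAYER = 3
    let bottom = if is_bottom_layer { KEY_BOTTOM_LAYER } else { KEY_NONE };
    bottom + 2 * (is_odd as u64)
}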
@@ -245,6 +231,7 @@ fn compute_is_last(path_bits: Vec<bool>, last_bits: Vec<bool>) -> Vec<bool> {
mod tests {
use super::*;
use plonky2::field::types::Field;
use crate::circuits::keyed_compress::key_compress;
// types used in all tests
type F = GoldilocksField;
@@ -255,12 +242,7 @@ mod tests {
y: HashOut<F>,
key: u64,
) -> HashOut<F> {
let key_field = F::from_canonical_u64(key);
let mut inputs = Vec::new();
inputs.extend_from_slice(&x.elements);
inputs.extend_from_slice(&y.elements);
inputs.push(key_field);
PoseidonHash::hash_no_pad(&inputs)
key_compress::<F,HF>(x,y,key)
}
fn make_tree(