add key_compress and refactor
This commit is contained in:
parent
c725043c4d
commit
4d723f8c7f
|
@ -0,0 +1,56 @@
|
||||||
|
use plonky2::hash::hash_types::{HashOut, HashOutTarget, RichField, NUM_HASH_OUT_ELTS};
|
||||||
|
use plonky2::hash::hashing::PlonkyPermutation;
|
||||||
|
use plonky2::iop::target::Target;
|
||||||
|
use plonky2::plonk::circuit_builder::CircuitBuilder;
|
||||||
|
use plonky2::plonk::config::{AlgebraicHasher, Hasher};
|
||||||
|
use plonky2_field::extension::Extendable;
|
||||||
|
use plonky2_poseidon2::poseidon2_hash::poseidon2::Poseidon2;
|
||||||
|
|
||||||
|
/// Compression function which takes two 256 bit inputs (HashOut) and u64 key (which is converted to field element in the function)
|
||||||
|
/// and returns a 256 bit output (HashOut).
|
||||||
|
pub fn key_compress<F: RichField, H:Hasher<F> >(x: HashOut<F>, y: HashOut<F>, key: u64) -> HashOut<F> {
|
||||||
|
|
||||||
|
debug_assert_eq!(x.elements.len(), NUM_HASH_OUT_ELTS);
|
||||||
|
debug_assert_eq!(y.elements.len(), NUM_HASH_OUT_ELTS);
|
||||||
|
|
||||||
|
let key_field = F::from_canonical_u64(key);
|
||||||
|
|
||||||
|
let mut perm = H::Permutation::new(core::iter::repeat(F::ZERO));
|
||||||
|
perm.set_from_slice(&x.elements, 0);
|
||||||
|
perm.set_from_slice(&y.elements, NUM_HASH_OUT_ELTS);
|
||||||
|
perm.set_elt(key_field,NUM_HASH_OUT_ELTS*2);
|
||||||
|
|
||||||
|
perm.permute();
|
||||||
|
|
||||||
|
HashOut {
|
||||||
|
elements: perm.squeeze()[..NUM_HASH_OUT_ELTS].try_into().unwrap(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// same as above but in-circuit
|
||||||
|
pub fn key_compress_circuit<
|
||||||
|
F: RichField + Extendable<D> + Poseidon2,
|
||||||
|
const D: usize,
|
||||||
|
H: AlgebraicHasher<F>,
|
||||||
|
>(
|
||||||
|
builder: &mut CircuitBuilder<F, D>,
|
||||||
|
x: Vec<Target>,
|
||||||
|
y: Vec<Target>,
|
||||||
|
key: Target,
|
||||||
|
) -> HashOutTarget {
|
||||||
|
let zero = builder.zero();
|
||||||
|
let mut state = H::AlgebraicPermutation::new(core::iter::repeat(zero));
|
||||||
|
|
||||||
|
state.set_from_slice(&x, 0);
|
||||||
|
state.set_from_slice(&y, NUM_HASH_OUT_ELTS);
|
||||||
|
state.set_elt(key, NUM_HASH_OUT_ELTS*2);
|
||||||
|
|
||||||
|
state = builder.permute::<H>(state);
|
||||||
|
|
||||||
|
HashOutTarget {
|
||||||
|
elements: state.squeeze()[..NUM_HASH_OUT_ELTS].try_into().unwrap(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -4,3 +4,4 @@ pub mod prove_single_cell;
|
||||||
pub mod sample_cells;
|
pub mod sample_cells;
|
||||||
pub mod utils;
|
pub mod utils;
|
||||||
pub mod params;
|
pub mod params;
|
||||||
|
pub mod keyed_compress;
|
|
@ -270,9 +270,9 @@ impl<
|
||||||
// Create virtual targets
|
// Create virtual targets
|
||||||
let mut leaf = (0..N_FIELD_ELEMS_PER_CELL).map(|_| builder.add_virtual_target()).collect::<Vec<_>>();
|
let mut leaf = (0..N_FIELD_ELEMS_PER_CELL).map(|_| builder.add_virtual_target()).collect::<Vec<_>>();
|
||||||
|
|
||||||
let mut perm_inputs:Vec<Target>= Vec::new();
|
let mut hash_inputs:Vec<Target>= Vec::new();
|
||||||
perm_inputs.extend_from_slice(&leaf);
|
hash_inputs.extend_from_slice(&leaf);
|
||||||
let leaf_hash = builder.hash_n_to_hash_no_pad::<HF>(perm_inputs);
|
let leaf_hash = builder.hash_n_to_hash_no_pad::<HF>(hash_inputs);
|
||||||
|
|
||||||
// path bits (binary decomposition of leaf_index)
|
// path bits (binary decomposition of leaf_index)
|
||||||
let mut block_path_bits = (0..BOT_DEPTH).map(|_| builder.add_virtual_bool_target_safe()).collect::<Vec<_>>();
|
let mut block_path_bits = (0..BOT_DEPTH).map(|_| builder.add_virtual_bool_target_safe()).collect::<Vec<_>>();
|
||||||
|
|
|
@ -18,6 +18,7 @@ use plonky2::plonk::proof::{Proof, ProofWithPublicInputs};
|
||||||
use std::marker::PhantomData;
|
use std::marker::PhantomData;
|
||||||
use plonky2_poseidon2::poseidon2_hash::poseidon2::Poseidon2;
|
use plonky2_poseidon2::poseidon2_hash::poseidon2::Poseidon2;
|
||||||
use serde::Serialize;
|
use serde::Serialize;
|
||||||
|
use crate::circuits::keyed_compress::key_compress_circuit;
|
||||||
use crate::circuits::params::HF;
|
use crate::circuits::params::HF;
|
||||||
use crate::circuits::utils::usize_to_bits_le_padded;
|
use crate::circuits::utils::usize_to_bits_le_padded;
|
||||||
|
|
||||||
|
@ -183,12 +184,7 @@ impl<
|
||||||
right.push( builder.select(bit, state.elements[i], sibling.elements[i]));
|
right.push( builder.select(bit, state.elements[i], sibling.elements[i]));
|
||||||
}
|
}
|
||||||
|
|
||||||
// hash left, right, and key
|
state = key_compress_circuit::<F,D,HF>(builder,left,right,key);
|
||||||
let mut perm_inputs:Vec<Target>= Vec::new();
|
|
||||||
perm_inputs.extend_from_slice(&left);
|
|
||||||
perm_inputs.extend_from_slice(&right);
|
|
||||||
perm_inputs.push(key);
|
|
||||||
state = builder.hash_n_to_hash_no_pad::<HF>(perm_inputs);
|
|
||||||
|
|
||||||
i += 1;
|
i += 1;
|
||||||
}
|
}
|
||||||
|
|
|
@ -291,9 +291,9 @@ impl<
|
||||||
// cell data targets
|
// cell data targets
|
||||||
let mut data_i = (0..N_FIELD_ELEMS_PER_CELL).map(|_| builder.add_virtual_target()).collect::<Vec<_>>();
|
let mut data_i = (0..N_FIELD_ELEMS_PER_CELL).map(|_| builder.add_virtual_target()).collect::<Vec<_>>();
|
||||||
|
|
||||||
let mut perm_inputs:Vec<Target>= Vec::new();
|
let mut hash_inputs:Vec<Target>= Vec::new();
|
||||||
perm_inputs.extend_from_slice(&data_i);
|
hash_inputs.extend_from_slice(&data_i);
|
||||||
let data_i_hash = builder.hash_n_to_hash_no_pad::<HF>(perm_inputs);
|
let data_i_hash = builder.hash_n_to_hash_no_pad::<HF>(hash_inputs);
|
||||||
// counter constant
|
// counter constant
|
||||||
let ctr_target = builder.constant(F::from_canonical_u64((i+1) as u64));
|
let ctr_target = builder.constant(F::from_canonical_u64((i+1) as u64));
|
||||||
let mut ctr = builder.add_virtual_hash();
|
let mut ctr = builder.add_virtual_hash();
|
||||||
|
|
|
@ -10,7 +10,8 @@ use plonky2::hash::poseidon::PoseidonHash;
|
||||||
use plonky2::plonk::config::Hasher;
|
use plonky2::plonk::config::Hasher;
|
||||||
use std::ops::Shr;
|
use std::ops::Shr;
|
||||||
use plonky2_field::types::Field;
|
use plonky2_field::types::Field;
|
||||||
|
use crate::circuits::keyed_compress::key_compress;
|
||||||
|
use crate::circuits::params::HF;
|
||||||
|
|
||||||
// Constants for the keys used in compression
|
// Constants for the keys used in compression
|
||||||
pub const KEY_NONE: u64 = 0x0;
|
pub const KEY_NONE: u64 = 0x0;
|
||||||
|
@ -18,11 +19,6 @@ pub const KEY_BOTTOM_LAYER: u64 = 0x1;
|
||||||
pub const KEY_ODD: u64 = 0x2;
|
pub const KEY_ODD: u64 = 0x2;
|
||||||
pub const KEY_ODD_AND_BOTTOM_LAYER: u64 = 0x3;
|
pub const KEY_ODD_AND_BOTTOM_LAYER: u64 = 0x3;
|
||||||
|
|
||||||
// hash function used. this is hackish way of doing it because
|
|
||||||
// H::Hash is not consistent with HashOut<F> and causing a lot of headache
|
|
||||||
// will look into this later.
|
|
||||||
type HF = PoseidonHash;
|
|
||||||
|
|
||||||
/// Merkle tree struct, containing the layers, compression function, and zero hash.
|
/// Merkle tree struct, containing the layers, compression function, and zero hash.
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
pub struct MerkleTree<F: RichField> {
|
pub struct MerkleTree<F: RichField> {
|
||||||
|
@ -92,16 +88,6 @@ impl<F: RichField> MerkleTree<F> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// compress input (x and y) with key using the define HF hash function
|
|
||||||
fn key_compress<F: RichField>(x: HashOut<F>, y: HashOut<F>, key: u64) -> HashOut<F> {
|
|
||||||
let key_field = F::from_canonical_u64(key);
|
|
||||||
let mut inputs = Vec::new();
|
|
||||||
inputs.extend_from_slice(&x.elements);
|
|
||||||
inputs.extend_from_slice(&y.elements);
|
|
||||||
inputs.push(key_field);
|
|
||||||
HF::hash_no_pad(&inputs) // TODO: double-check this function
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Build the Merkle tree layers.
|
/// Build the Merkle tree layers.
|
||||||
fn merkle_tree_worker<F: RichField>(
|
fn merkle_tree_worker<F: RichField>(
|
||||||
xs: &[HashOut<F>],
|
xs: &[HashOut<F>],
|
||||||
|
@ -121,7 +107,7 @@ fn merkle_tree_worker<F: RichField>(
|
||||||
|
|
||||||
for i in 0..halfn {
|
for i in 0..halfn {
|
||||||
let key = if is_bottom_layer { KEY_BOTTOM_LAYER } else { KEY_NONE };
|
let key = if is_bottom_layer { KEY_BOTTOM_LAYER } else { KEY_NONE };
|
||||||
let h = key_compress::<F>(xs[2 * i], xs[2 * i + 1], key);
|
let h = key_compress::<F, HF>(xs[2 * i], xs[2 * i + 1], key);
|
||||||
ys.push(h);
|
ys.push(h);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -131,7 +117,7 @@ fn merkle_tree_worker<F: RichField>(
|
||||||
} else {
|
} else {
|
||||||
KEY_ODD
|
KEY_ODD
|
||||||
};
|
};
|
||||||
let h = key_compress::<F>(xs[n], zero, key);
|
let h = key_compress::<F, HF>(xs[n], zero, key);
|
||||||
ys.push(h);
|
ys.push(h);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -169,14 +155,14 @@ impl<F: RichField> MerkleProof<F> {
|
||||||
let odd_index = (j & 1) != 0;
|
let odd_index = (j & 1) != 0;
|
||||||
if odd_index {
|
if odd_index {
|
||||||
// The index of the child is odd
|
// The index of the child is odd
|
||||||
h = key_compress::<F>(*p, h, bottom_flag);
|
h = key_compress::<F,HF>(*p, h, bottom_flag);
|
||||||
} else {
|
} else {
|
||||||
if j == m - 1 {
|
if j == m - 1 {
|
||||||
// Single child -> so odd node
|
// Single child -> so odd node
|
||||||
h = key_compress::<F>(h, *p, bottom_flag + 2);
|
h = key_compress::<F,HF>(h, *p, bottom_flag + 2);
|
||||||
} else {
|
} else {
|
||||||
// Even node
|
// Even node
|
||||||
h = key_compress::<F>(h, *p, bottom_flag);
|
h = key_compress::<F,HF>(h, *p, bottom_flag);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
bottom_flag = KEY_NONE;
|
bottom_flag = KEY_NONE;
|
||||||
|
@ -207,9 +193,9 @@ impl<F: RichField> MerkleProof<F> {
|
||||||
let key = bottom + (2 * (odd as u64));
|
let key = bottom + (2 * (odd as u64));
|
||||||
let odd_index = path_bits[i];
|
let odd_index = path_bits[i];
|
||||||
if odd_index {
|
if odd_index {
|
||||||
h = key_compress::<F>(*p, h, key);
|
h = key_compress::<F,HF>(*p, h, key);
|
||||||
} else {
|
} else {
|
||||||
h = key_compress::<F>(h, *p, key);
|
h = key_compress::<F,HF>(h, *p, key);
|
||||||
}
|
}
|
||||||
i += 1;
|
i += 1;
|
||||||
}
|
}
|
||||||
|
@ -245,6 +231,7 @@ fn compute_is_last(path_bits: Vec<bool>, last_bits: Vec<bool>) -> Vec<bool> {
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
use plonky2::field::types::Field;
|
use plonky2::field::types::Field;
|
||||||
|
use crate::circuits::keyed_compress::key_compress;
|
||||||
|
|
||||||
// types used in all tests
|
// types used in all tests
|
||||||
type F = GoldilocksField;
|
type F = GoldilocksField;
|
||||||
|
@ -255,12 +242,7 @@ mod tests {
|
||||||
y: HashOut<F>,
|
y: HashOut<F>,
|
||||||
key: u64,
|
key: u64,
|
||||||
) -> HashOut<F> {
|
) -> HashOut<F> {
|
||||||
let key_field = F::from_canonical_u64(key);
|
key_compress::<F,HF>(x,y,key)
|
||||||
let mut inputs = Vec::new();
|
|
||||||
inputs.extend_from_slice(&x.elements);
|
|
||||||
inputs.extend_from_slice(&y.elements);
|
|
||||||
inputs.push(key_field);
|
|
||||||
PoseidonHash::hash_no_pad(&inputs)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn make_tree(
|
fn make_tree(
|
||||||
|
|
Loading…
Reference in New Issue