diff --git a/src/bin/bench_recursion.rs b/src/bin/bench_recursion.rs index 117403fc..650bee11 100644 --- a/src/bin/bench_recursion.rs +++ b/src/bin/bench_recursion.rs @@ -5,6 +5,7 @@ use plonky2::field::crandall_field::CrandallField; use plonky2::field::extension_field::Extendable; use plonky2::field::field_types::RichField; use plonky2::fri::FriConfig; +use plonky2::hash::hashing::SPONGE_WIDTH; use plonky2::iop::witness::PartialWitness; use plonky2::plonk::circuit_builder::CircuitBuilder; use plonky2::plonk::circuit_data::CircuitConfig; @@ -41,7 +42,7 @@ fn bench_prove, const D: usize>() -> Result<()> { let zero = builder.zero(); let zero_ext = builder.zero_extension(); - let mut state = [zero; 12]; + let mut state = [zero; SPONGE_WIDTH]; for _ in 0..10000 { state = builder.permute(state); } diff --git a/src/gates/poseidon.rs b/src/gates/poseidon.rs index 3aef44ab..c45be25f 100644 --- a/src/gates/poseidon.rs +++ b/src/gates/poseidon.rs @@ -103,7 +103,7 @@ where let swap = vars.local_wires[Self::WIRE_SWAP]; constraints.push(swap * (swap - F::Extension::ONE)); - let mut state = Vec::with_capacity(12); + let mut state = Vec::with_capacity(WIDTH); for i in 0..4 { let a = vars.local_wires[i]; let b = vars.local_wires[i + 4]; @@ -114,7 +114,7 @@ where let b = vars.local_wires[i]; state.push(a + swap * (b - a)); } - for i in 8..12 { + for i in 8..WIDTH { state.push(vars.local_wires[i]); } @@ -182,7 +182,7 @@ where let swap = vars.local_wires[Self::WIRE_SWAP]; constraints.push(swap * (swap - F::ONE)); - let mut state = Vec::with_capacity(12); + let mut state = Vec::with_capacity(WIDTH); for i in 0..4 { let a = vars.local_wires[i]; let b = vars.local_wires[i + 4]; @@ -193,7 +193,7 @@ where let b = vars.local_wires[i]; state.push(a + swap * (b - a)); } - for i in 8..12 { + for i in 8..WIDTH { state.push(vars.local_wires[i]); } @@ -265,7 +265,7 @@ where let swap = vars.local_wires[Self::WIRE_SWAP]; constraints.push(builder.mul_sub_extension(swap, swap, swap)); - let mut state = Vec::with_capacity(12); + let mut state = Vec::with_capacity(WIDTH); for i in 0..4 { let a = vars.local_wires[i]; let b = vars.local_wires[i + 4]; @@ -278,7 +278,7 @@ where let delta = builder.sub_extension(b, a); state.push(builder.mul_add_extension(swap, delta, a)); } - for i in 8..12 { + for i in 8..WIDTH { state.push(vars.local_wires[i]); } @@ -498,6 +498,7 @@ mod tests { use crate::field::field_types::Field; use crate::gates::gate_testing::{test_eval_fns, test_low_degree}; use crate::gates::poseidon::PoseidonGate; + use crate::hash::hashing::SPONGE_WIDTH; use crate::hash::poseidon::Poseidon; use crate::iop::generator::generate_partial_witness; use crate::iop::wire::Wire; @@ -555,16 +556,14 @@ mod tests { #[test] fn low_degree() { type F = CrandallField; - const WIDTH: usize = 12; - let gate = PoseidonGate::::new(); + let gate = PoseidonGate::::new(); test_low_degree(gate) } #[test] fn eval_fns() -> Result<()> { type F = CrandallField; - const WIDTH: usize = 12; - let gate = PoseidonGate::::new(); + let gate = PoseidonGate::::new(); test_eval_fns(gate) } } diff --git a/src/hash/gmimc.rs b/src/hash/gmimc.rs index bb259d54..e64c940f 100644 --- a/src/hash/gmimc.rs +++ b/src/hash/gmimc.rs @@ -79,10 +79,18 @@ const CRANDALL_AND_GOLDILOCKS_ROUND_CONSTANTS: [u64; NUM_ROUNDS] = [ 0x780f22441e8dbc04, ]; +impl GMiMC<8> for CrandallField { + const ROUND_CONSTANTS: [u64; NUM_ROUNDS] = CRANDALL_AND_GOLDILOCKS_ROUND_CONSTANTS; +} + impl GMiMC<12> for CrandallField { const ROUND_CONSTANTS: [u64; NUM_ROUNDS] = CRANDALL_AND_GOLDILOCKS_ROUND_CONSTANTS; } +impl GMiMC<8> for GoldilocksField { + const ROUND_CONSTANTS: [u64; NUM_ROUNDS] = CRANDALL_AND_GOLDILOCKS_ROUND_CONSTANTS; +} + impl GMiMC<12> for GoldilocksField { const ROUND_CONSTANTS: [u64; NUM_ROUNDS] = CRANDALL_AND_GOLDILOCKS_ROUND_CONSTANTS; } diff --git a/src/hash/hashing.rs b/src/hash/hashing.rs index ae14e058..d031ebbb 100644 --- a/src/hash/hashing.rs +++ b/src/hash/hashing.rs @@ -1,5 +1,7 @@ //! Concrete instantiation of a hash function. +use std::convert::TryInto; + use crate::field::extension_field::Extendable; use crate::field::field_types::RichField; use crate::hash::hash_types::{HashOut, HashOutTarget}; @@ -8,7 +10,7 @@ use crate::plonk::circuit_builder::CircuitBuilder; pub(crate) const SPONGE_RATE: usize = 8; pub(crate) const SPONGE_CAPACITY: usize = 4; -pub(crate) const SPONGE_WIDTH: usize = SPONGE_RATE + SPONGE_CAPACITY; +pub const SPONGE_WIDTH: usize = SPONGE_RATE + SPONGE_CAPACITY; pub(crate) const HASH_FAMILY: HashFamily = HashFamily::Poseidon; @@ -85,10 +87,12 @@ impl, const D: usize> CircuitBuilder { /// A one-way compression function which takes two ~256 bit inputs and returns a ~256 bit output. pub fn compress(x: HashOut, y: HashOut) -> HashOut { - let mut inputs = Vec::with_capacity(8); - inputs.extend(&x.elements); - inputs.extend(&y.elements); - hash_n_to_hash(inputs, false) + let mut perm_inputs = [F::ZERO; SPONGE_WIDTH]; + perm_inputs[..4].copy_from_slice(&x.elements); + perm_inputs[4..8].copy_from_slice(&y.elements); + HashOut { + elements: permute(perm_inputs)[..4].try_into().unwrap(), + } } /// If `pad` is enabled, the message is padded using the pad10*1 rule. In general this is required diff --git a/src/hash/merkle_proofs.rs b/src/hash/merkle_proofs.rs index 85719b51..ac62d8f0 100644 --- a/src/hash/merkle_proofs.rs +++ b/src/hash/merkle_proofs.rs @@ -7,7 +7,7 @@ use crate::field::extension_field::target::ExtensionTarget; use crate::field::extension_field::Extendable; use crate::field::field_types::{Field, RichField}; use crate::hash::hash_types::{HashOut, HashOutTarget, MerkleCapTarget}; -use crate::hash::hashing::{compress, hash_or_noop}; +use crate::hash::hashing::{compress, hash_or_noop, SPONGE_WIDTH}; use crate::hash::merkle_tree::MerkleCap; use crate::iop::target::{BoolTarget, Target}; use crate::plonk::circuit_builder::CircuitBuilder; @@ -64,14 +64,12 @@ impl, const D: usize> CircuitBuilder { proof: &MerkleProofTarget, ) { let zero = self.zero(); - let mut state: HashOutTarget = self.hash_or_noop(leaf_data); for (&bit, &sibling) in leaf_index_bits.iter().zip(&proof.siblings) { - let perm_inputs = [state.elements, sibling.elements, [zero; 4]] - .concat() - .try_into() - .unwrap(); + let mut perm_inputs = [zero; SPONGE_WIDTH]; + perm_inputs[..4].copy_from_slice(&state.elements); + perm_inputs[4..8].copy_from_slice(&sibling.elements); let outputs = self.permute_swapped(perm_inputs, bit); state = HashOutTarget::from_vec(outputs[0..4].to_vec()); } @@ -101,16 +99,13 @@ impl, const D: usize> CircuitBuilder { proof: &MerkleProofTarget, ) { let zero = self.zero(); - let zero_x4 = [zero; 4]; - let mut state: HashOutTarget = self.hash_or_noop(leaf_data); for (&bit, &sibling) in leaf_index_bits.iter().zip(&proof.siblings) { - let inputs = [state.elements, sibling.elements, zero_x4] - .concat() - .try_into() - .unwrap(); - let perm_outs = self.permute_swapped(inputs, bit); + let mut perm_inputs = [zero; SPONGE_WIDTH]; + perm_inputs[..4].copy_from_slice(&state.elements); + perm_inputs[4..8].copy_from_slice(&sibling.elements); + let perm_outs = self.permute_swapped(perm_inputs, bit); let hash_outs = perm_outs[0..4].try_into().unwrap(); state = HashOutTarget { elements: hash_outs,