diff --git a/src/bin/bench_recursion.rs b/src/bin/bench_recursion.rs index 256864b1..650bee11 100644 --- a/src/bin/bench_recursion.rs +++ b/src/bin/bench_recursion.rs @@ -5,6 +5,7 @@ use plonky2::field::crandall_field::CrandallField; use plonky2::field::extension_field::Extendable; use plonky2::field::field_types::RichField; use plonky2::fri::FriConfig; +use plonky2::hash::hashing::SPONGE_WIDTH; use plonky2::iop::witness::PartialWitness; use plonky2::plonk::circuit_builder::CircuitBuilder; use plonky2::plonk::circuit_data::CircuitConfig; @@ -41,7 +42,7 @@ fn bench_prove, const D: usize>() -> Result<()> { let zero = builder.zero(); let zero_ext = builder.zero_extension(); - let mut state = [zero; 8]; + let mut state = [zero; SPONGE_WIDTH]; for _ in 0..10000 { state = builder.permute(state); } diff --git a/src/gates/poseidon.rs b/src/gates/poseidon.rs index 0d2a358b..bbc42b0d 100644 --- a/src/gates/poseidon.rs +++ b/src/gates/poseidon.rs @@ -5,6 +5,7 @@ use crate::field::extension_field::target::ExtensionTarget; use crate::field::extension_field::Extendable; use crate::field::field_types::{Field, RichField}; use crate::gates::gate::Gate; +use crate::hash::hashing::SPONGE_WIDTH; use crate::hash::poseidon; use crate::hash::poseidon::Poseidon; use crate::iop::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator}; @@ -103,7 +104,7 @@ where let swap = vars.local_wires[Self::WIRE_SWAP]; constraints.push(swap * (swap - F::Extension::ONE)); - let mut state = Vec::with_capacity(8); + let mut state = Vec::with_capacity(SPONGE_WIDTH); for i in 0..4 { let a = vars.local_wires[i]; let b = vars.local_wires[i + 4]; @@ -114,6 +115,9 @@ where let b = vars.local_wires[i]; state.push(a + swap * (b - a)); } + for i in 8..SPONGE_WIDTH { + state.push(vars.local_wires[i]); + } let mut state: [F::Extension; WIDTH] = state.try_into().unwrap(); let mut round_ctr = 0; @@ -179,7 +183,7 @@ where let swap = vars.local_wires[Self::WIRE_SWAP]; constraints.push(swap * (swap - F::ONE)); - let mut state = Vec::with_capacity(8); + let mut state = Vec::with_capacity(SPONGE_WIDTH); for i in 0..4 { let a = vars.local_wires[i]; let b = vars.local_wires[i + 4]; @@ -190,6 +194,9 @@ where let b = vars.local_wires[i]; state.push(a + swap * (b - a)); } + for i in 8..SPONGE_WIDTH { + state.push(vars.local_wires[i]); + } let mut state: [F; WIDTH] = state.try_into().unwrap(); let mut round_ctr = 0; @@ -259,7 +266,7 @@ where let swap = vars.local_wires[Self::WIRE_SWAP]; constraints.push(builder.mul_sub_extension(swap, swap, swap)); - let mut state = Vec::with_capacity(8); + let mut state = Vec::with_capacity(SPONGE_WIDTH); for i in 0..4 { let a = vars.local_wires[i]; let b = vars.local_wires[i + 4]; @@ -272,6 +279,9 @@ where let delta = builder.sub_extension(b, a); state.push(builder.mul_add_extension(swap, delta, a)); } + for i in 8..SPONGE_WIDTH { + state.push(vars.local_wires[i]); + } let mut state: [ExtensionTarget; WIDTH] = state.try_into().unwrap(); let mut round_ctr = 0; @@ -489,6 +499,7 @@ mod tests { use crate::field::field_types::Field; use crate::gates::gate_testing::{test_eval_fns, test_low_degree}; use crate::gates::poseidon::PoseidonGate; + use crate::hash::hashing::SPONGE_WIDTH; use crate::hash::poseidon::Poseidon; use crate::iop::generator::generate_partial_witness; use crate::iop::wire::Wire; @@ -546,16 +557,14 @@ mod tests { #[test] fn low_degree() { type F = CrandallField; - const WIDTH: usize = 8; - let gate = PoseidonGate::::new(); + let gate = PoseidonGate::::new(); test_low_degree(gate) } #[test] fn eval_fns() -> Result<()> { type F = CrandallField; - const WIDTH: usize = 8; - let gate = PoseidonGate::::new(); + let gate = PoseidonGate::::new(); test_eval_fns(gate) } } diff --git a/src/hash/hashing.rs b/src/hash/hashing.rs index dee3e320..77ee1534 100644 --- a/src/hash/hashing.rs +++ b/src/hash/hashing.rs @@ -11,7 +11,7 @@ use crate::plonk::circuit_builder::CircuitBuilder; pub(crate) const SPONGE_RATE: usize = 4; pub(crate) const SPONGE_CAPACITY: usize = 4; -pub(crate) const SPONGE_WIDTH: usize = SPONGE_RATE + SPONGE_CAPACITY; +pub const SPONGE_WIDTH: usize = SPONGE_RATE + SPONGE_CAPACITY; pub(crate) const HASH_FAMILY: HashFamily = HashFamily::Poseidon; @@ -88,7 +88,9 @@ impl, const D: usize> CircuitBuilder { /// A one-way compression function which takes two ~256 bit inputs and returns a ~256 bit output. pub fn compress(x: HashOut, y: HashOut) -> HashOut { - let perm_inputs = [x.elements, y.elements].concat().try_into().unwrap(); + let mut perm_inputs = [F::ZERO; SPONGE_WIDTH]; + perm_inputs[..4].copy_from_slice(&x.elements); + perm_inputs[4..8].copy_from_slice(&y.elements); HashOut { elements: permute(perm_inputs)[..4].try_into().unwrap(), } diff --git a/src/hash/merkle_proofs.rs b/src/hash/merkle_proofs.rs index d595d61f..ac62d8f0 100644 --- a/src/hash/merkle_proofs.rs +++ b/src/hash/merkle_proofs.rs @@ -7,7 +7,7 @@ use crate::field::extension_field::target::ExtensionTarget; use crate::field::extension_field::Extendable; use crate::field::field_types::{Field, RichField}; use crate::hash::hash_types::{HashOut, HashOutTarget, MerkleCapTarget}; -use crate::hash::hashing::{compress, hash_or_noop}; +use crate::hash::hashing::{compress, hash_or_noop, SPONGE_WIDTH}; use crate::hash::merkle_tree::MerkleCap; use crate::iop::target::{BoolTarget, Target}; use crate::plonk::circuit_builder::CircuitBuilder; @@ -63,13 +63,13 @@ impl, const D: usize> CircuitBuilder { merkle_cap: &MerkleCapTarget, proof: &MerkleProofTarget, ) { + let zero = self.zero(); let mut state: HashOutTarget = self.hash_or_noop(leaf_data); for (&bit, &sibling) in leaf_index_bits.iter().zip(&proof.siblings) { - let perm_inputs = [state.elements, sibling.elements] - .concat() - .try_into() - .unwrap(); + let mut perm_inputs = [zero; SPONGE_WIDTH]; + perm_inputs[..4].copy_from_slice(&state.elements); + perm_inputs[4..8].copy_from_slice(&sibling.elements); let outputs = self.permute_swapped(perm_inputs, bit); state = HashOutTarget::from_vec(outputs[0..4].to_vec()); } @@ -98,14 +98,14 @@ impl, const D: usize> CircuitBuilder { merkle_cap: &MerkleCapTarget, proof: &MerkleProofTarget, ) { + let zero = self.zero(); let mut state: HashOutTarget = self.hash_or_noop(leaf_data); for (&bit, &sibling) in leaf_index_bits.iter().zip(&proof.siblings) { - let inputs = [state.elements, sibling.elements] - .concat() - .try_into() - .unwrap(); - let perm_outs = self.permute_swapped(inputs, bit); + let mut perm_inputs = [zero; SPONGE_WIDTH]; + perm_inputs[..4].copy_from_slice(&state.elements); + perm_inputs[4..8].copy_from_slice(&sibling.elements); + let perm_outs = self.permute_swapped(perm_inputs, bit); let hash_outs = perm_outs[0..4].try_into().unwrap(); state = HashOutTarget { elements: hash_outs,