From 42a7ff9cc2bae22a78fff273b68550120e10a466 Mon Sep 17 00:00:00 2001 From: wborgeaud Date: Fri, 24 Sep 2021 13:06:07 +0200 Subject: [PATCH] Working --- src/bin/bench_recursion.rs | 2 +- src/field/field_types.rs | 2 +- src/gates/poseidon.rs | 19 +++++-------------- src/hash/gmimc.rs | 4 ++++ src/hash/hashing.rs | 12 +++++++----- src/hash/merkle_proofs.rs | 9 ++------- 6 files changed, 20 insertions(+), 28 deletions(-) diff --git a/src/bin/bench_recursion.rs b/src/bin/bench_recursion.rs index 117403fc..256864b1 100644 --- a/src/bin/bench_recursion.rs +++ b/src/bin/bench_recursion.rs @@ -41,7 +41,7 @@ fn bench_prove, const D: usize>() -> Result<()> { let zero = builder.zero(); let zero_ext = builder.zero_extension(); - let mut state = [zero; 12]; + let mut state = [zero; 8]; for _ in 0..10000 { state = builder.permute(state); } diff --git a/src/field/field_types.rs b/src/field/field_types.rs index 82f27d60..97db0948 100644 --- a/src/field/field_types.rs +++ b/src/field/field_types.rs @@ -16,7 +16,7 @@ use crate::hash::poseidon::Poseidon; use crate::util::bits_u64; /// A prime order field with the features we need to use it as a base field in our argument system. -pub trait RichField: PrimeField + GMiMC<12> + Poseidon<12> {} +pub trait RichField: PrimeField + GMiMC<8> + Poseidon<8> {} /// A finite field. pub trait Field: diff --git a/src/gates/poseidon.rs b/src/gates/poseidon.rs index 3aef44ab..0d2a358b 100644 --- a/src/gates/poseidon.rs +++ b/src/gates/poseidon.rs @@ -103,7 +103,7 @@ where let swap = vars.local_wires[Self::WIRE_SWAP]; constraints.push(swap * (swap - F::Extension::ONE)); - let mut state = Vec::with_capacity(12); + let mut state = Vec::with_capacity(8); for i in 0..4 { let a = vars.local_wires[i]; let b = vars.local_wires[i + 4]; @@ -114,9 +114,6 @@ where let b = vars.local_wires[i]; state.push(a + swap * (b - a)); } - for i in 8..12 { - state.push(vars.local_wires[i]); - } let mut state: [F::Extension; WIDTH] = state.try_into().unwrap(); let mut round_ctr = 0; @@ -182,7 +179,7 @@ where let swap = vars.local_wires[Self::WIRE_SWAP]; constraints.push(swap * (swap - F::ONE)); - let mut state = Vec::with_capacity(12); + let mut state = Vec::with_capacity(8); for i in 0..4 { let a = vars.local_wires[i]; let b = vars.local_wires[i + 4]; @@ -193,9 +190,6 @@ where let b = vars.local_wires[i]; state.push(a + swap * (b - a)); } - for i in 8..12 { - state.push(vars.local_wires[i]); - } let mut state: [F; WIDTH] = state.try_into().unwrap(); let mut round_ctr = 0; @@ -265,7 +259,7 @@ where let swap = vars.local_wires[Self::WIRE_SWAP]; constraints.push(builder.mul_sub_extension(swap, swap, swap)); - let mut state = Vec::with_capacity(12); + let mut state = Vec::with_capacity(8); for i in 0..4 { let a = vars.local_wires[i]; let b = vars.local_wires[i + 4]; @@ -278,9 +272,6 @@ where let delta = builder.sub_extension(b, a); state.push(builder.mul_add_extension(swap, delta, a)); } - for i in 8..12 { - state.push(vars.local_wires[i]); - } let mut state: [ExtensionTarget; WIDTH] = state.try_into().unwrap(); let mut round_ctr = 0; @@ -555,7 +546,7 @@ mod tests { #[test] fn low_degree() { type F = CrandallField; - const WIDTH: usize = 12; + const WIDTH: usize = 8; let gate = PoseidonGate::::new(); test_low_degree(gate) } @@ -563,7 +554,7 @@ mod tests { #[test] fn eval_fns() -> Result<()> { type F = CrandallField; - const WIDTH: usize = 12; + const WIDTH: usize = 8; let gate = PoseidonGate::::new(); test_eval_fns(gate) } diff --git a/src/hash/gmimc.rs b/src/hash/gmimc.rs index bb259d54..44af69d6 100644 --- a/src/hash/gmimc.rs +++ b/src/hash/gmimc.rs @@ -79,6 +79,10 @@ const CRANDALL_AND_GOLDILOCKS_ROUND_CONSTANTS: [u64; NUM_ROUNDS] = [ 0x780f22441e8dbc04, ]; +impl GMiMC<8> for CrandallField { + const ROUND_CONSTANTS: [u64; NUM_ROUNDS] = CRANDALL_AND_GOLDILOCKS_ROUND_CONSTANTS; +} + impl GMiMC<12> for CrandallField { const ROUND_CONSTANTS: [u64; NUM_ROUNDS] = CRANDALL_AND_GOLDILOCKS_ROUND_CONSTANTS; } diff --git a/src/hash/hashing.rs b/src/hash/hashing.rs index d5474cc4..dee3e320 100644 --- a/src/hash/hashing.rs +++ b/src/hash/hashing.rs @@ -1,5 +1,7 @@ //! Concrete instantiation of a hash function. +use std::convert::TryInto; + use crate::field::extension_field::Extendable; use crate::field::field_types::RichField; use crate::gates::poseidon::PoseidonGate; @@ -7,7 +9,7 @@ use crate::hash::hash_types::{HashOut, HashOutTarget}; use crate::iop::target::Target; use crate::plonk::circuit_builder::CircuitBuilder; -pub(crate) const SPONGE_RATE: usize = 8; +pub(crate) const SPONGE_RATE: usize = 4; pub(crate) const SPONGE_CAPACITY: usize = 4; pub(crate) const SPONGE_WIDTH: usize = SPONGE_RATE + SPONGE_CAPACITY; @@ -86,10 +88,10 @@ impl, const D: usize> CircuitBuilder { /// A one-way compression function which takes two ~256 bit inputs and returns a ~256 bit output. pub fn compress(x: HashOut, y: HashOut) -> HashOut { - let mut inputs = Vec::with_capacity(8); - inputs.extend(&x.elements); - inputs.extend(&y.elements); - hash_n_to_hash(inputs, false) + let perm_inputs = [x.elements, y.elements].concat().try_into().unwrap(); + HashOut { + elements: permute(perm_inputs)[..4].try_into().unwrap(), + } } /// If `pad` is enabled, the message is padded using the pad10*1 rule. In general this is required diff --git a/src/hash/merkle_proofs.rs b/src/hash/merkle_proofs.rs index 85719b51..d595d61f 100644 --- a/src/hash/merkle_proofs.rs +++ b/src/hash/merkle_proofs.rs @@ -63,12 +63,10 @@ impl, const D: usize> CircuitBuilder { merkle_cap: &MerkleCapTarget, proof: &MerkleProofTarget, ) { - let zero = self.zero(); - let mut state: HashOutTarget = self.hash_or_noop(leaf_data); for (&bit, &sibling) in leaf_index_bits.iter().zip(&proof.siblings) { - let perm_inputs = [state.elements, sibling.elements, [zero; 4]] + let perm_inputs = [state.elements, sibling.elements] .concat() .try_into() .unwrap(); @@ -100,13 +98,10 @@ impl, const D: usize> CircuitBuilder { merkle_cap: &MerkleCapTarget, proof: &MerkleProofTarget, ) { - let zero = self.zero(); - let zero_x4 = [zero; 4]; - let mut state: HashOutTarget = self.hash_or_noop(leaf_data); for (&bit, &sibling) in leaf_index_bits.iter().zip(&proof.siblings) { - let inputs = [state.elements, sibling.elements, zero_x4] + let inputs = [state.elements, sibling.elements] .concat() .try_into() .unwrap();