Merge pull request #270 from mir-protocol/poseidon_8

Use `SPONGE_WIDTH` instead of hardcoded values in various places
This commit is contained in:
wborgeaud 2021-09-27 12:55:46 +02:00 committed by GitHub
commit 1a508d0c19
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 36 additions and 29 deletions

View File

@ -5,6 +5,7 @@ use plonky2::field::crandall_field::CrandallField;
use plonky2::field::extension_field::Extendable; use plonky2::field::extension_field::Extendable;
use plonky2::field::field_types::RichField; use plonky2::field::field_types::RichField;
use plonky2::fri::FriConfig; use plonky2::fri::FriConfig;
use plonky2::hash::hashing::SPONGE_WIDTH;
use plonky2::iop::witness::PartialWitness; use plonky2::iop::witness::PartialWitness;
use plonky2::plonk::circuit_builder::CircuitBuilder; use plonky2::plonk::circuit_builder::CircuitBuilder;
use plonky2::plonk::circuit_data::CircuitConfig; use plonky2::plonk::circuit_data::CircuitConfig;
@ -41,7 +42,7 @@ fn bench_prove<F: RichField + Extendable<D>, const D: usize>() -> Result<()> {
let zero = builder.zero(); let zero = builder.zero();
let zero_ext = builder.zero_extension(); let zero_ext = builder.zero_extension();
let mut state = [zero; 12]; let mut state = [zero; SPONGE_WIDTH];
for _ in 0..10000 { for _ in 0..10000 {
state = builder.permute(state); state = builder.permute(state);
} }

View File

@ -103,7 +103,7 @@ where
let swap = vars.local_wires[Self::WIRE_SWAP]; let swap = vars.local_wires[Self::WIRE_SWAP];
constraints.push(swap * (swap - F::Extension::ONE)); constraints.push(swap * (swap - F::Extension::ONE));
let mut state = Vec::with_capacity(12); let mut state = Vec::with_capacity(WIDTH);
for i in 0..4 { for i in 0..4 {
let a = vars.local_wires[i]; let a = vars.local_wires[i];
let b = vars.local_wires[i + 4]; let b = vars.local_wires[i + 4];
@ -114,7 +114,7 @@ where
let b = vars.local_wires[i]; let b = vars.local_wires[i];
state.push(a + swap * (b - a)); state.push(a + swap * (b - a));
} }
for i in 8..12 { for i in 8..WIDTH {
state.push(vars.local_wires[i]); state.push(vars.local_wires[i]);
} }
@ -182,7 +182,7 @@ where
let swap = vars.local_wires[Self::WIRE_SWAP]; let swap = vars.local_wires[Self::WIRE_SWAP];
constraints.push(swap * (swap - F::ONE)); constraints.push(swap * (swap - F::ONE));
let mut state = Vec::with_capacity(12); let mut state = Vec::with_capacity(WIDTH);
for i in 0..4 { for i in 0..4 {
let a = vars.local_wires[i]; let a = vars.local_wires[i];
let b = vars.local_wires[i + 4]; let b = vars.local_wires[i + 4];
@ -193,7 +193,7 @@ where
let b = vars.local_wires[i]; let b = vars.local_wires[i];
state.push(a + swap * (b - a)); state.push(a + swap * (b - a));
} }
for i in 8..12 { for i in 8..WIDTH {
state.push(vars.local_wires[i]); state.push(vars.local_wires[i]);
} }
@ -265,7 +265,7 @@ where
let swap = vars.local_wires[Self::WIRE_SWAP]; let swap = vars.local_wires[Self::WIRE_SWAP];
constraints.push(builder.mul_sub_extension(swap, swap, swap)); constraints.push(builder.mul_sub_extension(swap, swap, swap));
let mut state = Vec::with_capacity(12); let mut state = Vec::with_capacity(WIDTH);
for i in 0..4 { for i in 0..4 {
let a = vars.local_wires[i]; let a = vars.local_wires[i];
let b = vars.local_wires[i + 4]; let b = vars.local_wires[i + 4];
@ -278,7 +278,7 @@ where
let delta = builder.sub_extension(b, a); let delta = builder.sub_extension(b, a);
state.push(builder.mul_add_extension(swap, delta, a)); state.push(builder.mul_add_extension(swap, delta, a));
} }
for i in 8..12 { for i in 8..WIDTH {
state.push(vars.local_wires[i]); state.push(vars.local_wires[i]);
} }
@ -498,6 +498,7 @@ mod tests {
use crate::field::field_types::Field; use crate::field::field_types::Field;
use crate::gates::gate_testing::{test_eval_fns, test_low_degree}; use crate::gates::gate_testing::{test_eval_fns, test_low_degree};
use crate::gates::poseidon::PoseidonGate; use crate::gates::poseidon::PoseidonGate;
use crate::hash::hashing::SPONGE_WIDTH;
use crate::hash::poseidon::Poseidon; use crate::hash::poseidon::Poseidon;
use crate::iop::generator::generate_partial_witness; use crate::iop::generator::generate_partial_witness;
use crate::iop::wire::Wire; use crate::iop::wire::Wire;
@ -555,16 +556,14 @@ mod tests {
#[test] #[test]
fn low_degree() { fn low_degree() {
type F = CrandallField; type F = CrandallField;
const WIDTH: usize = 12; let gate = PoseidonGate::<F, 4, SPONGE_WIDTH>::new();
let gate = PoseidonGate::<F, 4, WIDTH>::new();
test_low_degree(gate) test_low_degree(gate)
} }
#[test] #[test]
fn eval_fns() -> Result<()> { fn eval_fns() -> Result<()> {
type F = CrandallField; type F = CrandallField;
const WIDTH: usize = 12; let gate = PoseidonGate::<F, 4, SPONGE_WIDTH>::new();
let gate = PoseidonGate::<F, 4, WIDTH>::new();
test_eval_fns(gate) test_eval_fns(gate)
} }
} }

View File

@ -79,10 +79,18 @@ const CRANDALL_AND_GOLDILOCKS_ROUND_CONSTANTS: [u64; NUM_ROUNDS] = [
0x780f22441e8dbc04, 0x780f22441e8dbc04,
]; ];
impl GMiMC<8> for CrandallField {
const ROUND_CONSTANTS: [u64; NUM_ROUNDS] = CRANDALL_AND_GOLDILOCKS_ROUND_CONSTANTS;
}
impl GMiMC<12> for CrandallField { impl GMiMC<12> for CrandallField {
const ROUND_CONSTANTS: [u64; NUM_ROUNDS] = CRANDALL_AND_GOLDILOCKS_ROUND_CONSTANTS; const ROUND_CONSTANTS: [u64; NUM_ROUNDS] = CRANDALL_AND_GOLDILOCKS_ROUND_CONSTANTS;
} }
impl GMiMC<8> for GoldilocksField {
const ROUND_CONSTANTS: [u64; NUM_ROUNDS] = CRANDALL_AND_GOLDILOCKS_ROUND_CONSTANTS;
}
impl GMiMC<12> for GoldilocksField { impl GMiMC<12> for GoldilocksField {
const ROUND_CONSTANTS: [u64; NUM_ROUNDS] = CRANDALL_AND_GOLDILOCKS_ROUND_CONSTANTS; const ROUND_CONSTANTS: [u64; NUM_ROUNDS] = CRANDALL_AND_GOLDILOCKS_ROUND_CONSTANTS;
} }

View File

@ -1,5 +1,7 @@
//! Concrete instantiation of a hash function. //! Concrete instantiation of a hash function.
use std::convert::TryInto;
use crate::field::extension_field::Extendable; use crate::field::extension_field::Extendable;
use crate::field::field_types::RichField; use crate::field::field_types::RichField;
use crate::hash::hash_types::{HashOut, HashOutTarget}; use crate::hash::hash_types::{HashOut, HashOutTarget};
@ -8,7 +10,7 @@ use crate::plonk::circuit_builder::CircuitBuilder;
pub(crate) const SPONGE_RATE: usize = 8; pub(crate) const SPONGE_RATE: usize = 8;
pub(crate) const SPONGE_CAPACITY: usize = 4; pub(crate) const SPONGE_CAPACITY: usize = 4;
pub(crate) const SPONGE_WIDTH: usize = SPONGE_RATE + SPONGE_CAPACITY; pub const SPONGE_WIDTH: usize = SPONGE_RATE + SPONGE_CAPACITY;
pub(crate) const HASH_FAMILY: HashFamily = HashFamily::Poseidon; pub(crate) const HASH_FAMILY: HashFamily = HashFamily::Poseidon;
@ -85,10 +87,12 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
/// A one-way compression function which takes two ~256 bit inputs and returns a ~256 bit output. /// A one-way compression function which takes two ~256 bit inputs and returns a ~256 bit output.
pub fn compress<F: RichField>(x: HashOut<F>, y: HashOut<F>) -> HashOut<F> { pub fn compress<F: RichField>(x: HashOut<F>, y: HashOut<F>) -> HashOut<F> {
let mut inputs = Vec::with_capacity(8); let mut perm_inputs = [F::ZERO; SPONGE_WIDTH];
inputs.extend(&x.elements); perm_inputs[..4].copy_from_slice(&x.elements);
inputs.extend(&y.elements); perm_inputs[4..8].copy_from_slice(&y.elements);
hash_n_to_hash(inputs, false) HashOut {
elements: permute(perm_inputs)[..4].try_into().unwrap(),
}
} }
/// If `pad` is enabled, the message is padded using the pad10*1 rule. In general this is required /// If `pad` is enabled, the message is padded using the pad10*1 rule. In general this is required

View File

@ -7,7 +7,7 @@ use crate::field::extension_field::target::ExtensionTarget;
use crate::field::extension_field::Extendable; use crate::field::extension_field::Extendable;
use crate::field::field_types::{Field, RichField}; use crate::field::field_types::{Field, RichField};
use crate::hash::hash_types::{HashOut, HashOutTarget, MerkleCapTarget}; use crate::hash::hash_types::{HashOut, HashOutTarget, MerkleCapTarget};
use crate::hash::hashing::{compress, hash_or_noop}; use crate::hash::hashing::{compress, hash_or_noop, SPONGE_WIDTH};
use crate::hash::merkle_tree::MerkleCap; use crate::hash::merkle_tree::MerkleCap;
use crate::iop::target::{BoolTarget, Target}; use crate::iop::target::{BoolTarget, Target};
use crate::plonk::circuit_builder::CircuitBuilder; use crate::plonk::circuit_builder::CircuitBuilder;
@ -64,14 +64,12 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
proof: &MerkleProofTarget, proof: &MerkleProofTarget,
) { ) {
let zero = self.zero(); let zero = self.zero();
let mut state: HashOutTarget = self.hash_or_noop(leaf_data); let mut state: HashOutTarget = self.hash_or_noop(leaf_data);
for (&bit, &sibling) in leaf_index_bits.iter().zip(&proof.siblings) { for (&bit, &sibling) in leaf_index_bits.iter().zip(&proof.siblings) {
let perm_inputs = [state.elements, sibling.elements, [zero; 4]] let mut perm_inputs = [zero; SPONGE_WIDTH];
.concat() perm_inputs[..4].copy_from_slice(&state.elements);
.try_into() perm_inputs[4..8].copy_from_slice(&sibling.elements);
.unwrap();
let outputs = self.permute_swapped(perm_inputs, bit); let outputs = self.permute_swapped(perm_inputs, bit);
state = HashOutTarget::from_vec(outputs[0..4].to_vec()); state = HashOutTarget::from_vec(outputs[0..4].to_vec());
} }
@ -101,16 +99,13 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
proof: &MerkleProofTarget, proof: &MerkleProofTarget,
) { ) {
let zero = self.zero(); let zero = self.zero();
let zero_x4 = [zero; 4];
let mut state: HashOutTarget = self.hash_or_noop(leaf_data); let mut state: HashOutTarget = self.hash_or_noop(leaf_data);
for (&bit, &sibling) in leaf_index_bits.iter().zip(&proof.siblings) { for (&bit, &sibling) in leaf_index_bits.iter().zip(&proof.siblings) {
let inputs = [state.elements, sibling.elements, zero_x4] let mut perm_inputs = [zero; SPONGE_WIDTH];
.concat() perm_inputs[..4].copy_from_slice(&state.elements);
.try_into() perm_inputs[4..8].copy_from_slice(&sibling.elements);
.unwrap(); let perm_outs = self.permute_swapped(perm_inputs, bit);
let perm_outs = self.permute_swapped(inputs, bit);
let hash_outs = perm_outs[0..4].try_into().unwrap(); let hash_outs = perm_outs[0..4].try_into().unwrap();
state = HashOutTarget { state = HashOutTarget {
elements: hash_outs, elements: hash_outs,