diff --git a/src/gadgets/hash.rs b/src/gadgets/hash.rs index 01519912..82de1095 100644 --- a/src/gadgets/hash.rs +++ b/src/gadgets/hash.rs @@ -3,8 +3,10 @@ use std::convert::TryInto; use crate::field::extension_field::Extendable; use crate::field::field_types::RichField; use crate::gates::gmimc::GMiMCGate; +use crate::gates::poseidon::PoseidonGate; use crate::hash::gmimc::GMiMC; use crate::hash::hashing::{HashFamily, HASH_FAMILY}; +use crate::hash::poseidon::Poseidon; use crate::iop::target::{BoolTarget, Target}; use crate::iop::wire::Wire; use crate::plonk::circuit_builder::CircuitBuilder; @@ -13,7 +15,8 @@ use crate::plonk::circuit_builder::CircuitBuilder; impl, const D: usize> CircuitBuilder { pub fn permute(&mut self, inputs: [Target; W]) -> [Target; W] where - F: GMiMC, + F: GMiMC + Poseidon, + [(); W - 1]: , { // We don't want to swap any inputs, so set that wire to 0. let _false = self._false(); @@ -28,11 +31,12 @@ impl, const D: usize> CircuitBuilder { swap: BoolTarget, ) -> [Target; W] where - F: GMiMC, + F: GMiMC + Poseidon, + [(); W - 1]: , { match HASH_FAMILY { HashFamily::GMiMC => self.gmimc_permute_swapped(inputs, swap), - HashFamily::Poseidon => todo!(), + HashFamily::Poseidon => self.poseidon_permute_swapped(inputs, swap), } } @@ -79,4 +83,38 @@ impl, const D: usize> CircuitBuilder { .try_into() .unwrap() } + + /// Conditionally swap two chunks of the inputs (useful in verifying Merkle proofs), then apply + /// the Poseidon permutation. + pub(crate) fn poseidon_permute_swapped( + &mut self, + inputs: [Target; W], + swap: BoolTarget, + ) -> [Target; W] + where + F: Poseidon, + [(); W - 1]: , + { + let gate_type = PoseidonGate::::new(); + let gate = self.add_gate(gate_type, vec![]); + + // We don't want to swap any inputs, so set that wire to 0. + let swap_wire = PoseidonGate::::WIRE_SWAP; + let swap_wire = Target::wire(gate, swap_wire); + self.connect(swap.target, swap_wire); + + // Route input wires. + for i in 0..W { + let in_wire = PoseidonGate::::wire_input(i); + let in_wire = Target::wire(gate, in_wire); + self.connect(inputs[i], in_wire); + } + + // Collect output wires. + (0..W) + .map(|i| Target::wire(gate, PoseidonGate::::wire_output(i))) + .collect::>() + .try_into() + .unwrap() + } } diff --git a/src/gates/gmimc.rs b/src/gates/gmimc.rs index 9df6360c..5c031cfe 100644 --- a/src/gates/gmimc.rs +++ b/src/gates/gmimc.rs @@ -13,13 +13,11 @@ use crate::iop::witness::{PartitionWitness, Witness}; use crate::plonk::circuit_builder::CircuitBuilder; use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase}; -/// Evaluates a full GMiMC permutation with 12 state elements, and writes the output to the next -/// gate's first `width` wires (which could be the input of another `GMiMCGate`). +/// Evaluates a full GMiMC permutation with 12 state elements. /// /// This also has some extra features to make it suitable for efficiently verifying Merkle proofs. /// It has a flag which can be used to swap the first four inputs with the next four, for ordering -/// sibling digests. It also has an accumulator that computes the weighted sum of these flags, for -/// computing the index of the leaf based on these swap bits. +/// sibling digests. #[derive(Debug)] pub struct GMiMCGate< F: RichField + Extendable + GMiMC, diff --git a/src/gates/mod.rs b/src/gates/mod.rs index 053ba3fa..9f0be32a 100644 --- a/src/gates/mod.rs +++ b/src/gates/mod.rs @@ -12,6 +12,7 @@ pub mod gmimc; pub mod insertion; pub mod interpolation; pub mod noop; +pub mod poseidon; pub(crate) mod public_input; pub mod random_access; pub mod reducing; diff --git a/src/gates/poseidon.rs b/src/gates/poseidon.rs new file mode 100644 index 00000000..3aef44ab --- /dev/null +++ b/src/gates/poseidon.rs @@ -0,0 +1,570 @@ +use std::convert::TryInto; +use std::marker::PhantomData; + +use crate::field::extension_field::target::ExtensionTarget; +use crate::field::extension_field::Extendable; +use crate::field::field_types::{Field, RichField}; +use crate::gates::gate::Gate; +use crate::hash::poseidon; +use crate::hash::poseidon::Poseidon; +use crate::iop::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator}; +use crate::iop::target::Target; +use crate::iop::wire::Wire; +use crate::iop::witness::{PartitionWitness, Witness}; +use crate::plonk::circuit_builder::CircuitBuilder; +use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase}; + +/// Evaluates a full Poseidon permutation with 12 state elements. +/// +/// This also has some extra features to make it suitable for efficiently verifying Merkle proofs. +/// It has a flag which can be used to swap the first four inputs with the next four, for ordering +/// sibling digests. +#[derive(Debug)] +pub struct PoseidonGate< + F: RichField + Extendable + Poseidon, + const D: usize, + const WIDTH: usize, +> where + [(); WIDTH - 1]: , +{ + _phantom: PhantomData, +} + +impl + Poseidon, const D: usize, const WIDTH: usize> + PoseidonGate +where + [(); WIDTH - 1]: , +{ + pub fn new() -> Self { + PoseidonGate { + _phantom: PhantomData, + } + } + + /// The wire index for the `i`th input to the permutation. + pub fn wire_input(i: usize) -> usize { + i + } + + /// The wire index for the `i`th output to the permutation. + pub fn wire_output(i: usize) -> usize { + WIDTH + i + } + + /// If this is set to 1, the first four inputs will be swapped with the next four inputs. This + /// is useful for ordering hashes in Merkle proofs. Otherwise, this should be set to 0. + pub const WIRE_SWAP: usize = 2 * WIDTH; + + /// A wire which stores the input of the `i`-th S-box of the `round`-th round of the first set + /// of full rounds. + fn wire_full_sbox_0(round: usize, i: usize) -> usize { + debug_assert!(round < poseidon::HALF_N_FULL_ROUNDS); + debug_assert!(i < WIDTH); + 2 * WIDTH + 1 + WIDTH * round + i + } + + /// A wire which stores the input of the S-box of the `round`-th round of the partial rounds. + fn wire_partial_sbox(round: usize) -> usize { + debug_assert!(round < poseidon::N_PARTIAL_ROUNDS); + 2 * WIDTH + 1 + WIDTH * poseidon::HALF_N_FULL_ROUNDS + round + } + + /// A wire which stores the input of the `i`-th S-box of the `round`-th round of the second set + /// of full rounds. + fn wire_full_sbox_1(round: usize, i: usize) -> usize { + debug_assert!(round < poseidon::HALF_N_FULL_ROUNDS); + debug_assert!(i < WIDTH); + 2 * WIDTH + + 1 + + WIDTH * (poseidon::HALF_N_FULL_ROUNDS + round) + + poseidon::N_PARTIAL_ROUNDS + + i + } + + /// End of wire indices, exclusive. + fn end() -> usize { + 2 * WIDTH + 1 + WIDTH * poseidon::N_FULL_ROUNDS_TOTAL + poseidon::N_PARTIAL_ROUNDS + } +} + +impl + Poseidon, const D: usize, const WIDTH: usize> Gate + for PoseidonGate +where + [(); WIDTH - 1]: , +{ + fn id(&self) -> String { + format!(" {:?}", WIDTH, self) + } + + fn eval_unfiltered(&self, vars: EvaluationVars) -> Vec { + let mut constraints = Vec::with_capacity(self.num_constraints()); + + // Assert that `swap` is binary. + let swap = vars.local_wires[Self::WIRE_SWAP]; + constraints.push(swap * (swap - F::Extension::ONE)); + + let mut state = Vec::with_capacity(12); + for i in 0..4 { + let a = vars.local_wires[i]; + let b = vars.local_wires[i + 4]; + state.push(a + swap * (b - a)); + } + for i in 0..4 { + let a = vars.local_wires[i + 4]; + let b = vars.local_wires[i]; + state.push(a + swap * (b - a)); + } + for i in 8..12 { + state.push(vars.local_wires[i]); + } + + let mut state: [F::Extension; WIDTH] = state.try_into().unwrap(); + let mut round_ctr = 0; + + // First set of full rounds. + for r in 0..poseidon::HALF_N_FULL_ROUNDS { + >::constant_layer_field(&mut state, round_ctr); + for i in 0..WIDTH { + let sbox_in = vars.local_wires[Self::wire_full_sbox_0(r, i)]; + constraints.push(state[i] - sbox_in); + state[i] = sbox_in; + } + >::sbox_layer_field(&mut state); + state = >::mds_layer_field(&state); + round_ctr += 1; + } + + // Partial rounds. + >::partial_first_constant_layer(&mut state); + state = >::mds_partial_layer_init(&mut state); + for r in 0..(poseidon::N_PARTIAL_ROUNDS - 1) { + let sbox_in = vars.local_wires[Self::wire_partial_sbox(r)]; + constraints.push(state[0] - sbox_in); + state[0] = >::sbox_monomial(sbox_in); + state[0] += F::Extension::from_canonical_u64( + >::FAST_PARTIAL_ROUND_CONSTANTS[r], + ); + state = >::mds_partial_layer_fast_field(&state, r); + } + let sbox_in = vars.local_wires[Self::wire_partial_sbox(poseidon::N_PARTIAL_ROUNDS - 1)]; + constraints.push(state[0] - sbox_in); + state[0] = >::sbox_monomial(sbox_in); + state = >::mds_partial_layer_fast_field( + &state, + poseidon::N_PARTIAL_ROUNDS - 1, + ); + round_ctr += poseidon::N_PARTIAL_ROUNDS; + + // Second set of full rounds. + for r in 0..poseidon::HALF_N_FULL_ROUNDS { + >::constant_layer_field(&mut state, round_ctr); + for i in 0..WIDTH { + let sbox_in = vars.local_wires[Self::wire_full_sbox_1(r, i)]; + constraints.push(state[i] - sbox_in); + state[i] = sbox_in; + } + >::sbox_layer_field(&mut state); + state = >::mds_layer_field(&state); + round_ctr += 1; + } + + for i in 0..WIDTH { + constraints.push(state[i] - vars.local_wires[Self::wire_output(i)]); + } + + constraints + } + + fn eval_unfiltered_base(&self, vars: EvaluationVarsBase) -> Vec { + let mut constraints = Vec::with_capacity(self.num_constraints()); + + // Assert that `swap` is binary. + let swap = vars.local_wires[Self::WIRE_SWAP]; + constraints.push(swap * (swap - F::ONE)); + + let mut state = Vec::with_capacity(12); + for i in 0..4 { + let a = vars.local_wires[i]; + let b = vars.local_wires[i + 4]; + state.push(a + swap * (b - a)); + } + for i in 0..4 { + let a = vars.local_wires[i + 4]; + let b = vars.local_wires[i]; + state.push(a + swap * (b - a)); + } + for i in 8..12 { + state.push(vars.local_wires[i]); + } + + let mut state: [F; WIDTH] = state.try_into().unwrap(); + let mut round_ctr = 0; + + // First set of full rounds. + for r in 0..poseidon::HALF_N_FULL_ROUNDS { + >::constant_layer_field(&mut state, round_ctr); + for i in 0..WIDTH { + let sbox_in = vars.local_wires[Self::wire_full_sbox_0(r, i)]; + constraints.push(state[i] - sbox_in); + state[i] = sbox_in; + } + >::sbox_layer_field(&mut state); + state = >::mds_layer_field(&state); + round_ctr += 1; + } + + // Partial rounds. + >::partial_first_constant_layer(&mut state); + state = >::mds_partial_layer_init(&mut state); + for r in 0..(poseidon::N_PARTIAL_ROUNDS - 1) { + let sbox_in = vars.local_wires[Self::wire_partial_sbox(r)]; + constraints.push(state[0] - sbox_in); + state[0] = >::sbox_monomial(sbox_in); + state[0] += + F::from_canonical_u64(>::FAST_PARTIAL_ROUND_CONSTANTS[r]); + state = >::mds_partial_layer_fast_field(&state, r); + } + let sbox_in = vars.local_wires[Self::wire_partial_sbox(poseidon::N_PARTIAL_ROUNDS - 1)]; + constraints.push(state[0] - sbox_in); + state[0] = >::sbox_monomial(sbox_in); + state = >::mds_partial_layer_fast_field( + &state, + poseidon::N_PARTIAL_ROUNDS - 1, + ); + round_ctr += poseidon::N_PARTIAL_ROUNDS; + + // Second set of full rounds. + for r in 0..poseidon::HALF_N_FULL_ROUNDS { + >::constant_layer_field(&mut state, round_ctr); + for i in 0..WIDTH { + let sbox_in = vars.local_wires[Self::wire_full_sbox_1(r, i)]; + constraints.push(state[i] - sbox_in); + state[i] = sbox_in; + } + >::sbox_layer_field(&mut state); + state = >::mds_layer_field(&state); + round_ctr += 1; + } + + for i in 0..WIDTH { + constraints.push(state[i] - vars.local_wires[Self::wire_output(i)]); + } + + constraints + } + + fn eval_unfiltered_recursively( + &self, + builder: &mut CircuitBuilder, + vars: EvaluationTargets, + ) -> Vec> { + let one = builder.one_extension(); + let mut constraints = Vec::with_capacity(self.num_constraints()); + + // Assert that `swap` is binary. + let swap = vars.local_wires[Self::WIRE_SWAP]; + constraints.push(builder.mul_sub_extension(swap, swap, swap)); + + let mut state = Vec::with_capacity(12); + for i in 0..4 { + let a = vars.local_wires[i]; + let b = vars.local_wires[i + 4]; + let delta = builder.sub_extension(b, a); + state.push(builder.mul_add_extension(swap, delta, a)); + } + for i in 0..4 { + let a = vars.local_wires[i + 4]; + let b = vars.local_wires[i]; + let delta = builder.sub_extension(b, a); + state.push(builder.mul_add_extension(swap, delta, a)); + } + for i in 8..12 { + state.push(vars.local_wires[i]); + } + + let mut state: [ExtensionTarget; WIDTH] = state.try_into().unwrap(); + let mut round_ctr = 0; + + // First set of full rounds. + for r in 0..poseidon::HALF_N_FULL_ROUNDS { + >::constant_layer_recursive(builder, &mut state, round_ctr); + for i in 0..WIDTH { + let sbox_in = vars.local_wires[Self::wire_full_sbox_0(r, i)]; + constraints.push(builder.sub_extension(state[i], sbox_in)); + state[i] = sbox_in; + } + >::sbox_layer_recursive(builder, &mut state); + state = >::mds_layer_recursive(builder, &state); + round_ctr += 1; + } + + // Partial rounds. + >::partial_first_constant_layer_recursive(builder, &mut state); + state = >::mds_partial_layer_init_recursive(builder, &mut state); + for r in 0..(poseidon::N_PARTIAL_ROUNDS - 1) { + let sbox_in = vars.local_wires[Self::wire_partial_sbox(r)]; + constraints.push(builder.sub_extension(state[0], sbox_in)); + state[0] = >::sbox_monomial_recursive(builder, sbox_in); + state[0] = builder.arithmetic_extension( + F::from_canonical_u64(>::FAST_PARTIAL_ROUND_CONSTANTS[r]), + F::ONE, + one, + one, + state[0], + ); + state = >::mds_partial_layer_fast_recursive(builder, &state, r); + } + let sbox_in = vars.local_wires[Self::wire_partial_sbox(poseidon::N_PARTIAL_ROUNDS - 1)]; + constraints.push(builder.sub_extension(state[0], sbox_in)); + state[0] = >::sbox_monomial_recursive(builder, sbox_in); + state = >::mds_partial_layer_fast_recursive( + builder, + &state, + poseidon::N_PARTIAL_ROUNDS - 1, + ); + round_ctr += poseidon::N_PARTIAL_ROUNDS; + + // Second set of full rounds. + for r in 0..poseidon::HALF_N_FULL_ROUNDS { + >::constant_layer_recursive(builder, &mut state, round_ctr); + for i in 0..WIDTH { + let sbox_in = vars.local_wires[Self::wire_full_sbox_1(r, i)]; + constraints.push(builder.sub_extension(state[i], sbox_in)); + state[i] = sbox_in; + } + >::sbox_layer_recursive(builder, &mut state); + state = >::mds_layer_recursive(builder, &state); + round_ctr += 1; + } + + for i in 0..WIDTH { + constraints + .push(builder.sub_extension(state[i], vars.local_wires[Self::wire_output(i)])); + } + + constraints + } + + fn generators( + &self, + gate_index: usize, + _local_constants: &[F], + ) -> Vec>> { + let gen = PoseidonGenerator:: { + gate_index, + _phantom: PhantomData, + }; + vec![Box::new(gen.adapter())] + } + + fn num_wires(&self) -> usize { + Self::end() + } + + fn num_constants(&self) -> usize { + 0 + } + + fn degree(&self) -> usize { + 7 + } + + fn num_constraints(&self) -> usize { + WIDTH * poseidon::N_FULL_ROUNDS_TOTAL + poseidon::N_PARTIAL_ROUNDS + WIDTH + 1 + } +} + +#[derive(Debug)] +struct PoseidonGenerator< + F: RichField + Extendable + Poseidon, + const D: usize, + const WIDTH: usize, +> where + [(); WIDTH - 1]: , +{ + gate_index: usize, + _phantom: PhantomData, +} + +impl + Poseidon, const D: usize, const WIDTH: usize> + SimpleGenerator for PoseidonGenerator +where + [(); WIDTH - 1]: , +{ + fn dependencies(&self) -> Vec { + (0..WIDTH) + .map(|i| PoseidonGate::::wire_input(i)) + .chain(Some(PoseidonGate::::WIRE_SWAP)) + .map(|input| Target::wire(self.gate_index, input)) + .collect() + } + + fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { + let local_wire = |input| Wire { + gate: self.gate_index, + input, + }; + + let mut state = (0..WIDTH) + .map(|i| { + witness.get_wire(Wire { + gate: self.gate_index, + input: PoseidonGate::::wire_input(i), + }) + }) + .collect::>(); + + let swap_value = witness.get_wire(Wire { + gate: self.gate_index, + input: PoseidonGate::::WIRE_SWAP, + }); + debug_assert!(swap_value == F::ZERO || swap_value == F::ONE); + if swap_value == F::ONE { + for i in 0..4 { + state.swap(i, 4 + i); + } + } + + let mut state: [F; WIDTH] = state.try_into().unwrap(); + let mut round_ctr = 0; + + for r in 0..poseidon::HALF_N_FULL_ROUNDS { + >::constant_layer_field(&mut state, round_ctr); + for i in 0..WIDTH { + out_buffer.set_wire( + local_wire(PoseidonGate::::wire_full_sbox_0(r, i)), + state[i], + ); + } + >::sbox_layer_field(&mut state); + state = >::mds_layer_field(&state); + round_ctr += 1; + } + + >::partial_first_constant_layer(&mut state); + state = >::mds_partial_layer_init(&mut state); + for r in 0..(poseidon::N_PARTIAL_ROUNDS - 1) { + out_buffer.set_wire( + local_wire(PoseidonGate::::wire_partial_sbox(r)), + state[0], + ); + state[0] = >::sbox_monomial(state[0]); + state[0] += + F::from_canonical_u64(>::FAST_PARTIAL_ROUND_CONSTANTS[r]); + state = >::mds_partial_layer_fast_field(&state, r); + } + out_buffer.set_wire( + local_wire(PoseidonGate::::wire_partial_sbox( + poseidon::N_PARTIAL_ROUNDS - 1, + )), + state[0], + ); + state[0] = >::sbox_monomial(state[0]); + state = >::mds_partial_layer_fast_field( + &state, + poseidon::N_PARTIAL_ROUNDS - 1, + ); + round_ctr += poseidon::N_PARTIAL_ROUNDS; + + for r in 0..poseidon::HALF_N_FULL_ROUNDS { + >::constant_layer_field(&mut state, round_ctr); + for i in 0..WIDTH { + out_buffer.set_wire( + local_wire(PoseidonGate::::wire_full_sbox_1(r, i)), + state[i], + ); + } + >::sbox_layer_field(&mut state); + state = >::mds_layer_field(&state); + round_ctr += 1; + } + + for i in 0..WIDTH { + out_buffer.set_wire( + local_wire(PoseidonGate::::wire_output(i)), + state[i], + ); + } + } +} + +#[cfg(test)] +mod tests { + use std::convert::TryInto; + + use anyhow::Result; + + use crate::field::crandall_field::CrandallField; + use crate::field::field_types::Field; + use crate::gates::gate_testing::{test_eval_fns, test_low_degree}; + use crate::gates::poseidon::PoseidonGate; + use crate::hash::poseidon::Poseidon; + use crate::iop::generator::generate_partial_witness; + use crate::iop::wire::Wire; + use crate::iop::witness::{PartialWitness, Witness}; + use crate::plonk::circuit_builder::CircuitBuilder; + use crate::plonk::circuit_data::CircuitConfig; + + #[test] + fn generated_output() { + type F = CrandallField; + const WIDTH: usize = 12; + + let config = CircuitConfig { + num_wires: 143, + ..CircuitConfig::large_config() + }; + let mut builder = CircuitBuilder::new(config); + type Gate = PoseidonGate; + let gate = Gate::new(); + let gate_index = builder.add_gate(gate, vec![]); + let circuit = builder.build_prover(); + + let permutation_inputs = (0..WIDTH).map(F::from_canonical_usize).collect::>(); + + let mut inputs = PartialWitness::new(); + inputs.set_wire( + Wire { + gate: gate_index, + input: Gate::WIRE_SWAP, + }, + F::ZERO, + ); + for i in 0..WIDTH { + inputs.set_wire( + Wire { + gate: gate_index, + input: Gate::wire_input(i), + }, + permutation_inputs[i], + ); + } + + let witness = generate_partial_witness(inputs, &circuit.prover_only); + + let expected_outputs: [F; WIDTH] = F::poseidon(permutation_inputs.try_into().unwrap()); + for i in 0..WIDTH { + let out = witness.get_wire(Wire { + gate: 0, + input: Gate::wire_output(i), + }); + assert_eq!(out, expected_outputs[i]); + } + } + + #[test] + fn low_degree() { + type F = CrandallField; + const WIDTH: usize = 12; + let gate = PoseidonGate::::new(); + test_low_degree(gate) + } + + #[test] + fn eval_fns() -> Result<()> { + type F = CrandallField; + const WIDTH: usize = 12; + let gate = PoseidonGate::::new(); + test_eval_fns(gate) + } +} diff --git a/src/hash/hashing.rs b/src/hash/hashing.rs index 9b4791d3..00cd0a74 100644 --- a/src/hash/hashing.rs +++ b/src/hash/hashing.rs @@ -2,6 +2,7 @@ use crate::field::extension_field::Extendable; use crate::field::field_types::RichField; +use crate::gates::poseidon::PoseidonGate; use crate::hash::hash_types::{HashOut, HashOutTarget}; use crate::iop::target::Target; use crate::plonk::circuit_builder::CircuitBuilder; @@ -10,7 +11,8 @@ pub(crate) const SPONGE_RATE: usize = 8; pub(crate) const SPONGE_CAPACITY: usize = 4; pub(crate) const SPONGE_WIDTH: usize = SPONGE_RATE + SPONGE_CAPACITY; -pub(crate) const HASH_FAMILY: HashFamily = HashFamily::GMiMC; +pub(crate) const HASH_FAMILY: HashFamily = HashFamily::Poseidon; +pub(crate) type HashGate = PoseidonGate; pub(crate) enum HashFamily { GMiMC, diff --git a/src/hash/merkle_proofs.rs b/src/hash/merkle_proofs.rs index 22b6f318..7a176dd9 100644 --- a/src/hash/merkle_proofs.rs +++ b/src/hash/merkle_proofs.rs @@ -6,9 +6,8 @@ use serde::{Deserialize, Serialize}; use crate::field::extension_field::target::ExtensionTarget; use crate::field::extension_field::Extendable; use crate::field::field_types::{Field, RichField}; -use crate::gates::gmimc::GMiMCGate; use crate::hash::hash_types::{HashOut, HashOutTarget, MerkleCapTarget}; -use crate::hash::hashing::{compress, hash_or_noop}; +use crate::hash::hashing::{compress, hash_or_noop, HashGate}; use crate::hash::merkle_tree::MerkleCap; use crate::iop::target::{BoolTarget, Target}; use crate::iop::wire::Wire; @@ -74,7 +73,7 @@ impl, const D: usize> CircuitBuilder { .concat() .try_into() .unwrap(); - let outputs = self.gmimc_permute_swapped(inputs, bit); + let outputs = self.permute_swapped(inputs, bit); state = HashOutTarget::from_vec(outputs[0..4].to_vec()); } @@ -107,10 +106,10 @@ impl, const D: usize> CircuitBuilder { let mut state: HashOutTarget = self.hash_or_noop(leaf_data); for (&bit, &sibling) in leaf_index_bits.iter().zip(&proof.siblings) { - let gate_type = GMiMCGate::::new(); + let gate_type = HashGate::::new(); let gate = self.add_gate(gate_type, vec![]); - let swap_wire = GMiMCGate::::WIRE_SWAP; + let swap_wire = HashGate::::WIRE_SWAP; let swap_wire = Target::Wire(Wire { gate, input: swap_wire, @@ -121,7 +120,7 @@ impl, const D: usize> CircuitBuilder { .map(|i| { Target::Wire(Wire { gate, - input: GMiMCGate::::wire_input(i), + input: HashGate::::wire_input(i), }) }) .collect::>(); @@ -137,7 +136,7 @@ impl, const D: usize> CircuitBuilder { .map(|i| { Target::Wire(Wire { gate, - input: GMiMCGate::::wire_output(i), + input: HashGate::::wire_output(i), }) }) .collect(), diff --git a/src/hash/poseidon.rs b/src/hash/poseidon.rs index cec6ea71..ca519634 100644 --- a/src/hash/poseidon.rs +++ b/src/hash/poseidon.rs @@ -7,7 +7,10 @@ use std::convert::TryInto; use unroll::unroll_for_loops; use crate::field::crandall_field::CrandallField; -use crate::field::field_types::PrimeField; +use crate::field::extension_field::target::ExtensionTarget; +use crate::field::extension_field::{Extendable, FieldExtension}; +use crate::field::field_types::{PrimeField, RichField}; +use crate::plonk::circuit_builder::CircuitBuilder; // The number of full rounds and partial rounds is given by the // calc_round_numbers.py script. They happen to be the same for both @@ -15,9 +18,9 @@ use crate::field::field_types::PrimeField; // // NB: Changing any of these values will require regenerating all of // the precomputed constant arrays in this file. -const HALF_N_FULL_ROUNDS: usize = 4; -const N_FULL_ROUNDS_TOTAL: usize = 2 * HALF_N_FULL_ROUNDS; -const N_PARTIAL_ROUNDS: usize = 22; +pub(crate) const HALF_N_FULL_ROUNDS: usize = 4; +pub(crate) const N_FULL_ROUNDS_TOTAL: usize = 2 * HALF_N_FULL_ROUNDS; +pub(crate) const N_PARTIAL_ROUNDS: usize = 22; const N_ROUNDS: usize = N_FULL_ROUNDS_TOTAL + N_PARTIAL_ROUNDS; const MAX_WIDTH: usize = 12; // we only have width 8 and 12, and 12 is bigger. :) @@ -25,7 +28,7 @@ const MAX_WIDTH: usize = 12; // we only have width 8 and 12, and 12 is bigger. : /// `generate_constants` about how these were generated. We include enough for a WIDTH of 12; /// smaller widths just use a subset. #[rustfmt::skip] -const ALL_ROUND_CONSTANTS: [u64; MAX_WIDTH * N_ROUNDS] = [ +pub const ALL_ROUND_CONSTANTS: [u64; MAX_WIDTH * N_ROUNDS] = [ // WARNING: These must be in 0..CrandallField::ORDER (i.e. canonical form). If this condition is // not met, some platform-specific implementation of constant_layer may return incorrect // results. @@ -165,6 +168,49 @@ where res } + #[inline(always)] + #[unroll_for_loops] + /// Same as `mds_row_shf` for field extensions of `Self`. + fn mds_row_shf_field, const D: usize>( + r: usize, + v: &[F; WIDTH], + ) -> F { + debug_assert!(r < WIDTH); + let mut res = F::ZERO; + + assert!(WIDTH <= 12); + for i in 0..12 { + if i < WIDTH { + res += v[(i + r) % WIDTH] * F::from_canonical_u64(1 << Self::MDS_MATRIX_EXPS[i]); + } + } + + res + } + + /// Recursive version of `mds_row_shf`. + fn mds_row_shf_recursive, const D: usize>( + builder: &mut CircuitBuilder, + r: usize, + v: &[ExtensionTarget; WIDTH], + ) -> ExtensionTarget { + let one = builder.one_extension(); + debug_assert!(r < WIDTH); + let mut res = builder.zero_extension(); + + for i in 0..WIDTH { + res = builder.arithmetic_extension( + F::from_canonical_u64(1 << Self::MDS_MATRIX_EXPS[i]), + F::ONE, + one, + v[(i + r) % WIDTH], + res, + ); + } + + res + } + #[inline(always)] #[unroll_for_loops] fn mds_layer(state_: &[Self; WIDTH]) -> [Self; WIDTH] { @@ -188,19 +234,72 @@ where #[inline(always)] #[unroll_for_loops] - fn partial_first_constant_layer(state: &mut [Self; WIDTH]) { + /// Same as `mds_layer` for field extensions of `Self`. + fn mds_layer_field, const D: usize>( + state: &[F; WIDTH], + ) -> [F; WIDTH] { + let mut result = [F::ZERO; WIDTH]; + + assert!(WIDTH <= 12); + for r in 0..12 { + if r < WIDTH { + result[r] = Self::mds_row_shf_field(r, state); + } + } + + result + } + + /// Recursive version of `mds_layer`. + fn mds_layer_recursive, const D: usize>( + builder: &mut CircuitBuilder, + state: &[ExtensionTarget; WIDTH], + ) -> [ExtensionTarget; WIDTH] { + let mut result = [builder.zero_extension(); WIDTH]; + + for r in 0..WIDTH { + result[r] = Self::mds_row_shf_recursive(builder, r, state); + } + + result + } + + #[inline(always)] + #[unroll_for_loops] + fn partial_first_constant_layer, const D: usize>( + state: &mut [F; WIDTH], + ) { assert!(WIDTH <= 12); for i in 0..12 { if i < WIDTH { - state[i] += Self::from_canonical_u64(Self::FAST_PARTIAL_FIRST_ROUND_CONSTANT[i]); + state[i] += F::from_canonical_u64(Self::FAST_PARTIAL_FIRST_ROUND_CONSTANT[i]); } } } + /// Recursive version of `partial_first_constant_layer`. + fn partial_first_constant_layer_recursive, const D: usize>( + builder: &mut CircuitBuilder, + state: &mut [ExtensionTarget; WIDTH], + ) { + let one = builder.one_extension(); + for i in 0..WIDTH { + state[i] = builder.arithmetic_extension( + F::from_canonical_u64(Self::FAST_PARTIAL_FIRST_ROUND_CONSTANT[i]), + F::ONE, + one, + one, + state[i], + ); + } + } + #[inline(always)] #[unroll_for_loops] - fn mds_partial_layer_init(state: &[Self; WIDTH]) -> [Self; WIDTH] { - let mut result = [Self::ZERO; WIDTH]; + fn mds_partial_layer_init, const D: usize>( + state: &[F; WIDTH], + ) -> [F; WIDTH] { + let mut result = [F::ZERO; WIDTH]; // Initial matrix has first row/column = [1, 0, ..., 0]; @@ -216,7 +315,7 @@ where // NB: FAST_PARTIAL_ROUND_INITIAL_MATRIX is stored in // column-major order so that this dot product is cache // friendly. - let t = Self::from_canonical_u64( + let t = F::from_canonical_u64( Self::FAST_PARTIAL_ROUND_INITIAL_MATRIX[c - 1][r - 1], ); result[c] += state[r] * t; @@ -227,6 +326,30 @@ where result } + /// Recursive version of `mds_partial_layer_init`. + fn mds_partial_layer_init_recursive, const D: usize>( + builder: &mut CircuitBuilder, + state: &[ExtensionTarget; WIDTH], + ) -> [ExtensionTarget; WIDTH] { + let one = builder.one_extension(); + let mut result = [builder.zero_extension(); WIDTH]; + + result[0] = state[0]; + + for c in 1..WIDTH { + assert!(WIDTH <= 12); + for r in 1..12 { + if r < WIDTH { + let t = F::from_canonical_u64( + Self::FAST_PARTIAL_ROUND_INITIAL_MATRIX[c - 1][r - 1], + ); + result[c] = builder.arithmetic_extension(t, F::ONE, one, state[r], result[c]); + } + } + } + result + } + /// Computes s*A where s is the state row vector and A is the matrix /// /// [ M_00 | v ] @@ -263,6 +386,70 @@ where result } + #[inline(always)] + #[unroll_for_loops] + /// Same as `mds_partial_layer_fast` for field extensions of `Self`. + fn mds_partial_layer_fast_field, const D: usize>( + state: &[F; WIDTH], + r: usize, + ) -> [F; WIDTH] { + let s0 = state[0]; + let mut d = s0 * F::from_canonical_u64(1 << Self::MDS_MATRIX_EXPS[0]); + assert!(WIDTH <= 12); + for i in 1..12 { + if i < WIDTH { + let t = F::from_canonical_u64(Self::FAST_PARTIAL_ROUND_W_HATS[r][i - 1]); + d += state[i] * t; + } + } + + // result = [d] concat [state[0] * v + state[shift up by 1]] + let mut result = [F::ZERO; WIDTH]; + result[0] = d; + assert!(WIDTH <= 12); + for i in 1..12 { + if i < WIDTH { + let t = F::from_canonical_u64(Self::FAST_PARTIAL_ROUND_VS[r][i - 1]); + result[i] = state[0] * t + state[i]; + } + } + result + } + + /// Recursive version of `mds_partial_layer_fast`. + fn mds_partial_layer_fast_recursive, const D: usize>( + builder: &mut CircuitBuilder, + state: &[ExtensionTarget; WIDTH], + r: usize, + ) -> [ExtensionTarget; WIDTH] { + let zero = builder.zero_extension(); + let one = builder.one_extension(); + + let s0 = state[0]; + let mut d = builder.arithmetic_extension( + F::from_canonical_u64(1 << Self::MDS_MATRIX_EXPS[0]), + F::ONE, + one, + s0, + zero, + ); + for i in 1..WIDTH { + let t = F::from_canonical_u64(Self::FAST_PARTIAL_ROUND_W_HATS[r][i - 1]); + d = builder.arithmetic_extension(t, F::ONE, one, state[i], d); + } + + let mut result = [zero; WIDTH]; + result[0] = d; + assert!(WIDTH <= 12); + for i in 1..12 { + if i < WIDTH { + let t = F::from_canonical_u64(Self::FAST_PARTIAL_ROUND_VS[r][i - 1]); + result[i] = builder.arithmetic_extension(t, F::ONE, one, state[0], state[i]); + } + } + result + } + #[inline(always)] #[unroll_for_loops] fn constant_layer(state: &mut [Self; WIDTH], round_ctr: usize) { @@ -275,7 +462,40 @@ where } #[inline(always)] - fn sbox_monomial(x: Self) -> Self { + #[unroll_for_loops] + /// Same as `constant_layer` for field extensions of `Self`. + fn constant_layer_field, const D: usize>( + state: &mut [F; WIDTH], + round_ctr: usize, + ) { + assert!(WIDTH <= 12); + for i in 0..12 { + if i < WIDTH { + state[i] += F::from_canonical_u64(ALL_ROUND_CONSTANTS[i + WIDTH * round_ctr]); + } + } + } + + /// Recursive version of `constant_layer`. + fn constant_layer_recursive, const D: usize>( + builder: &mut CircuitBuilder, + state: &mut [ExtensionTarget; WIDTH], + round_ctr: usize, + ) { + let one = builder.one_extension(); + for i in 0..WIDTH { + state[i] = builder.arithmetic_extension( + F::from_canonical_u64(ALL_ROUND_CONSTANTS[i + WIDTH * round_ctr]), + F::ONE, + one, + one, + state[i], + ); + } + } + + #[inline(always)] + fn sbox_monomial, const D: usize>(x: F) -> F { // x |--> x^7 let x2 = x * x; let x4 = x2 * x2; @@ -283,6 +503,15 @@ where x3 * x4 } + /// Recursive version of `sbox_monomial`. + fn sbox_monomial_recursive, const D: usize>( + builder: &mut CircuitBuilder, + x: ExtensionTarget, + ) -> ExtensionTarget { + // x |--> x^7 + builder.exp_u64_extension(x, 7) + } + #[inline(always)] #[unroll_for_loops] fn sbox_layer(state: &mut [Self; WIDTH]) { @@ -294,6 +523,30 @@ where } } + #[inline(always)] + #[unroll_for_loops] + /// Same as `sbox_layer` for field extensions of `Self`. + fn sbox_layer_field, const D: usize>( + state: &mut [F; WIDTH], + ) { + assert!(WIDTH <= 12); + for i in 0..12 { + if i < WIDTH { + state[i] = Self::sbox_monomial(state[i]); + } + } + } + + /// Recursive version of `sbox_layer`. + fn sbox_layer_recursive, const D: usize>( + builder: &mut CircuitBuilder, + state: &mut [ExtensionTarget; WIDTH], + ) { + for i in 0..WIDTH { + state[i] = Self::sbox_monomial_recursive(builder, state[i]); + } + } + #[inline] fn full_rounds(state: &mut [Self; WIDTH], round_ctr: &mut usize) { for _ in 0..HALF_N_FULL_ROUNDS { diff --git a/src/iop/challenger.rs b/src/iop/challenger.rs index 47d57db8..d33c19ce 100644 --- a/src/iop/challenger.rs +++ b/src/iop/challenger.rs @@ -4,7 +4,7 @@ use crate::field::extension_field::target::ExtensionTarget; use crate::field::extension_field::{Extendable, FieldExtension}; use crate::field::field_types::RichField; use crate::hash::hash_types::{HashOut, HashOutTarget, MerkleCapTarget}; -use crate::hash::hashing::{SPONGE_RATE, SPONGE_WIDTH}; +use crate::hash::hashing::{permute, SPONGE_RATE, SPONGE_WIDTH}; use crate::hash::merkle_tree::MerkleCap; use crate::iop::target::Target; use crate::plonk::circuit_builder::CircuitBuilder; @@ -105,7 +105,7 @@ impl Challenger { if self.output_buffer.is_empty() { // Evaluate the permutation to produce `r` new outputs. - self.sponge_state = F::gmimc_permute(self.sponge_state); + self.sponge_state = permute(self.sponge_state); self.output_buffer = self.sponge_state[0..SPONGE_RATE].to_vec(); } @@ -160,7 +160,7 @@ impl Challenger { } // Apply the permutation. - self.sponge_state = F::gmimc_permute(self.sponge_state); + self.sponge_state = permute(self.sponge_state); } self.output_buffer = self.sponge_state[0..SPONGE_RATE].to_vec(); @@ -377,7 +377,7 @@ mod tests { } let config = CircuitConfig::large_config(); - let mut builder = CircuitBuilder::::new(config.clone()); + let mut builder = CircuitBuilder::::new(config); let mut recursive_challenger = RecursiveChallenger::new(&mut builder); let mut recursive_outputs_per_round: Vec> = Vec::new(); for (r, inputs) in inputs_per_round.iter().enumerate() { diff --git a/src/plonk/circuit_data.rs b/src/plonk/circuit_data.rs index 21dfc28b..540e3e84 100644 --- a/src/plonk/circuit_data.rs +++ b/src/plonk/circuit_data.rs @@ -61,7 +61,7 @@ impl CircuitConfig { #[cfg(test)] pub(crate) fn large_config() -> Self { Self { - num_wires: 126, + num_wires: 143, num_routed_wires: 64, security_bits: 128, rate_bits: 3, diff --git a/src/plonk/recursive_verifier.rs b/src/plonk/recursive_verifier.rs index 3bd50bec..0b371b9d 100644 --- a/src/plonk/recursive_verifier.rs +++ b/src/plonk/recursive_verifier.rs @@ -361,7 +361,7 @@ mod tests { type F = CrandallField; const D: usize = 4; let config = CircuitConfig { - num_wires: 126, + num_wires: 143, num_routed_wires: 33, security_bits: 128, rate_bits: 3, @@ -416,7 +416,7 @@ mod tests { type F = CrandallField; const D: usize = 4; let config = CircuitConfig { - num_wires: 126, + num_wires: 143, num_routed_wires: 64, security_bits: 128, rate_bits: 3,