From 26a222bbdf63f6d481ee7830d14b01c4fe9d9e1e Mon Sep 17 00:00:00 2001 From: Daniel Lubarov Date: Sun, 14 Nov 2021 11:57:36 -0800 Subject: [PATCH] Fewer wires in `PoseidonGate` (#356) Closes #345. --- src/gates/poseidon.rs | 213 ++++++++++++++++++++++++-------------- src/plonk/circuit_data.rs | 2 +- 2 files changed, 135 insertions(+), 80 deletions(-) diff --git a/src/gates/poseidon.rs b/src/gates/poseidon.rs index 6e1eb69a..59c23b44 100644 --- a/src/gates/poseidon.rs +++ b/src/gates/poseidon.rs @@ -56,35 +56,49 @@ where /// is useful for ordering hashes in Merkle proofs. Otherwise, this should be set to 0. pub const WIRE_SWAP: usize = 2 * WIDTH; + const START_DELTA: usize = 2 * WIDTH + 1; + + /// A wire which stores `swap * (input[i + 4] - input[i])`; used to compute the swapped inputs. + fn wire_delta(i: usize) -> usize { + assert!(i < 4); + Self::START_DELTA + i + } + + const START_FULL_0: usize = Self::START_DELTA + 4; + /// A wire which stores the input of the `i`-th S-box of the `round`-th round of the first set /// of full rounds. fn wire_full_sbox_0(round: usize, i: usize) -> usize { + debug_assert!( + round != 0, + "First round S-box inputs are not stored as wires" + ); debug_assert!(round < poseidon::HALF_N_FULL_ROUNDS); debug_assert!(i < WIDTH); - 2 * WIDTH + 1 + WIDTH * round + i + Self::START_FULL_0 + WIDTH * (round - 1) + i } + const START_PARTIAL: usize = Self::START_FULL_0 + WIDTH * (poseidon::HALF_N_FULL_ROUNDS - 1); + /// A wire which stores the input of the S-box of the `round`-th round of the partial rounds. fn wire_partial_sbox(round: usize) -> usize { debug_assert!(round < poseidon::N_PARTIAL_ROUNDS); - 2 * WIDTH + 1 + WIDTH * poseidon::HALF_N_FULL_ROUNDS + round + Self::START_PARTIAL + round } + const START_FULL_1: usize = Self::START_PARTIAL + poseidon::N_PARTIAL_ROUNDS; + /// A wire which stores the input of the `i`-th S-box of the `round`-th round of the second set /// of full rounds. fn wire_full_sbox_1(round: usize, i: usize) -> usize { debug_assert!(round < poseidon::HALF_N_FULL_ROUNDS); debug_assert!(i < WIDTH); - 2 * WIDTH - + 1 - + WIDTH * (poseidon::HALF_N_FULL_ROUNDS + round) - + poseidon::N_PARTIAL_ROUNDS - + i + Self::START_FULL_1 + WIDTH * round + i } /// End of wire indices, exclusive. fn end() -> usize { - 2 * WIDTH + 1 + WIDTH * poseidon::N_FULL_ROUNDS_TOTAL + poseidon::N_PARTIAL_ROUNDS + Self::START_FULL_1 + WIDTH * poseidon::HALF_N_FULL_ROUNDS } } @@ -104,31 +118,38 @@ where let swap = vars.local_wires[Self::WIRE_SWAP]; constraints.push(swap * (swap - F::Extension::ONE)); - let mut state = Vec::with_capacity(WIDTH); + // Assert that each delta wire is set properly: `delta_i = swap * (rhs - lhs)`. for i in 0..4 { - let a = vars.local_wires[i]; - let b = vars.local_wires[i + 4]; - state.push(a + swap * (b - a)); - } - for i in 0..4 { - let a = vars.local_wires[i + 4]; - let b = vars.local_wires[i]; - state.push(a + swap * (b - a)); - } - for i in 8..WIDTH { - state.push(vars.local_wires[i]); + let input_lhs = vars.local_wires[Self::wire_input(i)]; + let input_rhs = vars.local_wires[Self::wire_input(i + 4)]; + let delta_i = vars.local_wires[Self::wire_delta(i)]; + constraints.push(swap * (input_rhs - input_lhs) - delta_i); + } + + // Compute the possibly-swapped input layer. + let mut state = [F::Extension::ZERO; WIDTH]; + for i in 0..4 { + let delta_i = vars.local_wires[Self::wire_delta(i)]; + let input_lhs = Self::wire_input(i); + let input_rhs = Self::wire_input(i + 4); + state[i] = vars.local_wires[input_lhs] + delta_i; + state[i + 4] = vars.local_wires[input_rhs] - delta_i; + } + for i in 8..WIDTH { + state[i] = vars.local_wires[Self::wire_input(i)]; } - let mut state: [F::Extension; WIDTH] = state.try_into().unwrap(); let mut round_ctr = 0; // First set of full rounds. for r in 0..poseidon::HALF_N_FULL_ROUNDS { >::constant_layer_field(&mut state, round_ctr); - for i in 0..WIDTH { - let sbox_in = vars.local_wires[Self::wire_full_sbox_0(r, i)]; - constraints.push(state[i] - sbox_in); - state[i] = sbox_in; + if r != 0 { + for i in 0..WIDTH { + let sbox_in = vars.local_wires[Self::wire_full_sbox_0(r, i)]; + constraints.push(state[i] - sbox_in); + state[i] = sbox_in; + } } >::sbox_layer_field(&mut state); state = >::mds_layer_field(&state); @@ -183,31 +204,38 @@ where let swap = vars.local_wires[Self::WIRE_SWAP]; constraints.push(swap * swap.sub_one()); - let mut state = Vec::with_capacity(WIDTH); + // Assert that each delta wire is set properly: `delta_i = swap * (rhs - lhs)`. for i in 0..4 { - let a = vars.local_wires[i]; - let b = vars.local_wires[i + 4]; - state.push(a + swap * (b - a)); - } - for i in 0..4 { - let a = vars.local_wires[i + 4]; - let b = vars.local_wires[i]; - state.push(a + swap * (b - a)); - } - for i in 8..WIDTH { - state.push(vars.local_wires[i]); + let input_lhs = vars.local_wires[Self::wire_input(i)]; + let input_rhs = vars.local_wires[Self::wire_input(i + 4)]; + let delta_i = vars.local_wires[Self::wire_delta(i)]; + constraints.push(swap * (input_rhs - input_lhs) - delta_i); + } + + // Compute the possibly-swapped input layer. + let mut state = [F::ZERO; WIDTH]; + for i in 0..4 { + let delta_i = vars.local_wires[Self::wire_delta(i)]; + let input_lhs = Self::wire_input(i); + let input_rhs = Self::wire_input(i + 4); + state[i] = vars.local_wires[input_lhs] + delta_i; + state[i + 4] = vars.local_wires[input_rhs] - delta_i; + } + for i in 8..WIDTH { + state[i] = vars.local_wires[Self::wire_input(i)]; } - let mut state: [F; WIDTH] = state.try_into().unwrap(); let mut round_ctr = 0; // First set of full rounds. for r in 0..poseidon::HALF_N_FULL_ROUNDS { >::constant_layer(&mut state, round_ctr); - for i in 0..WIDTH { - let sbox_in = vars.local_wires[Self::wire_full_sbox_0(r, i)]; - constraints.push(state[i] - sbox_in); - state[i] = sbox_in; + if r != 0 { + for i in 0..WIDTH { + let sbox_in = vars.local_wires[Self::wire_full_sbox_0(r, i)]; + constraints.push(state[i] - sbox_in); + state[i] = sbox_in; + } } >::sbox_layer(&mut state); state = >::mds_layer(&state); @@ -267,38 +295,39 @@ where let swap = vars.local_wires[Self::WIRE_SWAP]; constraints.push(builder.mul_sub_extension(swap, swap, swap)); - let mut state = Vec::with_capacity(WIDTH); - // We need to compute both `if swap {b} else {a}` and `if swap {a} else {b}`. - // We will arithmetize them as - // swap (b - a) + a - // -swap (b - a) + b - // so that `b - a` can be used for both. - let mut state_first_4 = vec![]; - let mut state_next_4 = vec![]; + // Assert that each delta wire is set properly: `delta_i = swap * (rhs - lhs)`. for i in 0..4 { - let a = vars.local_wires[i]; - let b = vars.local_wires[i + 4]; - let delta = builder.sub_extension(b, a); - state_first_4.push(builder.mul_add_extension(swap, delta, a)); - state_next_4.push(builder.arithmetic_extension(F::NEG_ONE, F::ONE, swap, delta, b)); + let input_lhs = vars.local_wires[Self::wire_input(i)]; + let input_rhs = vars.local_wires[Self::wire_input(i + 4)]; + let delta_i = vars.local_wires[Self::wire_delta(i)]; + let diff = builder.sub_extension(input_rhs, input_lhs); + constraints.push(builder.mul_sub_extension(swap, diff, delta_i)); } - state.extend(state_first_4); - state.extend(state_next_4); + // Compute the possibly-swapped input layer. + let mut state = [builder.zero_extension(); WIDTH]; + for i in 0..4 { + let delta_i = vars.local_wires[Self::wire_delta(i)]; + let input_lhs = vars.local_wires[Self::wire_input(i)]; + let input_rhs = vars.local_wires[Self::wire_input(i + 4)]; + state[i] = builder.add_extension(input_lhs, delta_i); + state[i + 4] = builder.sub_extension(input_rhs, delta_i); + } for i in 8..WIDTH { - state.push(vars.local_wires[i]); + state[i] = vars.local_wires[Self::wire_input(i)]; } - let mut state: [ExtensionTarget; WIDTH] = state.try_into().unwrap(); let mut round_ctr = 0; // First set of full rounds. for r in 0..poseidon::HALF_N_FULL_ROUNDS { >::constant_layer_recursive(builder, &mut state, round_ctr); - for i in 0..WIDTH { - let sbox_in = vars.local_wires[Self::wire_full_sbox_0(r, i)]; - constraints.push(builder.sub_extension(state[i], sbox_in)); - state[i] = sbox_in; + if r != 0 { + for i in 0..WIDTH { + let sbox_in = vars.local_wires[Self::wire_full_sbox_0(r, i)]; + constraints.push(builder.sub_extension(state[i], sbox_in)); + state[i] = sbox_in; + } } >::sbox_layer_recursive(builder, &mut state); state = >::mds_layer_recursive(builder, &state); @@ -386,7 +415,7 @@ where } fn num_constraints(&self) -> usize { - WIDTH * poseidon::N_FULL_ROUNDS_TOTAL + poseidon::N_PARTIAL_ROUNDS + WIDTH + 1 + WIDTH * (poseidon::N_FULL_ROUNDS_TOTAL - 1) + poseidon::N_PARTIAL_ROUNDS + WIDTH + 1 + 4 } } @@ -422,19 +451,20 @@ where }; let mut state = (0..WIDTH) - .map(|i| { - witness.get_wire(Wire { - gate: self.gate_index, - input: PoseidonGate::::wire_input(i), - }) - }) + .map(|i| witness.get_wire(local_wire(PoseidonGate::::wire_input(i)))) .collect::>(); - let swap_value = witness.get_wire(Wire { - gate: self.gate_index, - input: PoseidonGate::::WIRE_SWAP, - }); + let swap_value = witness.get_wire(local_wire(PoseidonGate::::WIRE_SWAP)); debug_assert!(swap_value == F::ZERO || swap_value == F::ONE); + + for i in 0..4 { + let delta_i = swap_value * (state[i + 4] - state[i]); + out_buffer.set_wire( + local_wire(PoseidonGate::::wire_delta(i)), + delta_i, + ); + } + if swap_value == F::ONE { for i in 0..4 { state.swap(i, 4 + i); @@ -446,11 +476,13 @@ where for r in 0..poseidon::HALF_N_FULL_ROUNDS { >::constant_layer_field(&mut state, round_ctr); - for i in 0..WIDTH { - out_buffer.set_wire( - local_wire(PoseidonGate::::wire_full_sbox_0(r, i)), - state[i], - ); + if r != 0 { + for i in 0..WIDTH { + out_buffer.set_wire( + local_wire(PoseidonGate::::wire_full_sbox_0(r, i)), + state[i], + ); + } } >::sbox_layer_field(&mut state); state = >::mds_layer_field(&state); @@ -522,6 +554,29 @@ mod tests { use crate::plonk::circuit_builder::CircuitBuilder; use crate::plonk::circuit_data::CircuitConfig; + #[test] + fn wire_indices() { + type F = GoldilocksField; + const WIDTH: usize = 12; + type Gate = PoseidonGate; + + assert_eq!(Gate::wire_input(0), 0); + assert_eq!(Gate::wire_input(11), 11); + assert_eq!(Gate::wire_output(0), 12); + assert_eq!(Gate::wire_output(11), 23); + assert_eq!(Gate::WIRE_SWAP, 24); + assert_eq!(Gate::wire_delta(0), 25); + assert_eq!(Gate::wire_delta(3), 28); + assert_eq!(Gate::wire_full_sbox_0(1, 0), 29); + assert_eq!(Gate::wire_full_sbox_0(3, 0), 53); + assert_eq!(Gate::wire_full_sbox_0(3, 11), 64); + assert_eq!(Gate::wire_partial_sbox(0), 65); + assert_eq!(Gate::wire_partial_sbox(21), 86); + assert_eq!(Gate::wire_full_sbox_1(0, 0), 87); + assert_eq!(Gate::wire_full_sbox_1(3, 0), 123); + assert_eq!(Gate::wire_full_sbox_1(3, 11), 134); + } + #[test] fn generated_output() { type F = GoldilocksField; diff --git a/src/plonk/circuit_data.rs b/src/plonk/circuit_data.rs index d54d327d..564d558d 100644 --- a/src/plonk/circuit_data.rs +++ b/src/plonk/circuit_data.rs @@ -55,7 +55,7 @@ impl CircuitConfig { /// A typical recursion config, without zero-knowledge, targeting ~100 bit security. pub(crate) fn standard_recursion_config() -> Self { Self { - num_wires: 143, + num_wires: 135, num_routed_wires: 25, constant_gate_size: 6, use_base_arithmetic_gate: true,