From efb1365021c3e8d784121b8501b57bd22d72c5d6 Mon Sep 17 00:00:00 2001 From: Jakub Nabaglo Date: Mon, 7 Feb 2022 14:29:31 -0800 Subject: [PATCH] Split `system_zero::column_layout` into submodules (#475) --- system_zero/src/column_layout.rs | 149 ++++++++++++++-------------- system_zero/src/permutation_unit.rs | 62 ++++++------ 2 files changed, 104 insertions(+), 107 deletions(-) diff --git a/system_zero/src/column_layout.rs b/system_zero/src/column_layout.rs index 7a9e92e5..fa5d627a 100644 --- a/system_zero/src/column_layout.rs +++ b/system_zero/src/column_layout.rs @@ -1,6 +1,3 @@ -use plonky2::hash::hashing::SPONGE_WIDTH; -use plonky2::hash::poseidon; - //// CORE REGISTERS /// A cycle counter. Starts at 0; increments by 1. @@ -21,87 +18,91 @@ pub(crate) const COL_FRAME_PTR: usize = COL_INSTRUCTION_PTR + 1; pub(crate) const COL_STACK_PTR: usize = COL_FRAME_PTR + 1; //// PERMUTATION UNIT +pub(crate) mod permutation { + use plonky2::hash::hashing::SPONGE_WIDTH; + use plonky2::hash::poseidon; -const START_PERMUTATION_UNIT: usize = COL_STACK_PTR + 1; + const START_UNIT: usize = super::COL_STACK_PTR + 1; -const START_PERMUTATION_FULL_FIRST: usize = START_PERMUTATION_UNIT + SPONGE_WIDTH; + const START_FULL_FIRST: usize = START_UNIT + SPONGE_WIDTH; -pub(crate) const fn col_permutation_full_first_mid_sbox(round: usize, i: usize) -> usize { - debug_assert!(round < poseidon::HALF_N_FULL_ROUNDS); - debug_assert!(i < SPONGE_WIDTH); - START_PERMUTATION_FULL_FIRST + 2 * round * SPONGE_WIDTH + i + pub const fn col_full_first_mid_sbox(round: usize, i: usize) -> usize { + debug_assert!(round < poseidon::HALF_N_FULL_ROUNDS); + debug_assert!(i < SPONGE_WIDTH); + START_FULL_FIRST + 2 * round * SPONGE_WIDTH + i + } + + pub const fn col_full_first_after_mds(round: usize, i: usize) -> usize { + debug_assert!(round < poseidon::HALF_N_FULL_ROUNDS); + debug_assert!(i < SPONGE_WIDTH); + START_FULL_FIRST + (2 * round + 1) * SPONGE_WIDTH + i + } + + const START_PARTIAL: usize = + col_full_first_after_mds(poseidon::HALF_N_FULL_ROUNDS - 1, SPONGE_WIDTH - 1) + 1; + + pub const fn col_partial_mid_sbox(round: usize) -> usize { + debug_assert!(round < poseidon::N_PARTIAL_ROUNDS); + START_PARTIAL + 2 * round + } + + pub const fn col_partial_after_sbox(round: usize) -> usize { + debug_assert!(round < poseidon::N_PARTIAL_ROUNDS); + START_PARTIAL + 2 * round + 1 + } + + const START_FULL_SECOND: usize = col_partial_after_sbox(poseidon::N_PARTIAL_ROUNDS - 1) + 1; + + pub const fn col_full_second_mid_sbox(round: usize, i: usize) -> usize { + debug_assert!(round <= poseidon::HALF_N_FULL_ROUNDS); + debug_assert!(i < SPONGE_WIDTH); + START_FULL_SECOND + 2 * round * SPONGE_WIDTH + i + } + + pub const fn col_full_second_after_mds(round: usize, i: usize) -> usize { + debug_assert!(round <= poseidon::HALF_N_FULL_ROUNDS); + debug_assert!(i < SPONGE_WIDTH); + START_FULL_SECOND + (2 * round + 1) * SPONGE_WIDTH + i + } + + pub const fn col_input(i: usize) -> usize { + debug_assert!(i < SPONGE_WIDTH); + START_UNIT + i + } + + pub const fn col_output(i: usize) -> usize { + debug_assert!(i < SPONGE_WIDTH); + col_full_second_after_mds(poseidon::HALF_N_FULL_ROUNDS - 1, i) + } + + pub(super) const END_UNIT: usize = col_output(SPONGE_WIDTH - 1); } -pub(crate) const fn col_permutation_full_first_after_mds(round: usize, i: usize) -> usize { - debug_assert!(round < poseidon::HALF_N_FULL_ROUNDS); - debug_assert!(i < SPONGE_WIDTH); - START_PERMUTATION_FULL_FIRST + (2 * round + 1) * SPONGE_WIDTH + i -} - -const START_PERMUTATION_PARTIAL: usize = - col_permutation_full_first_after_mds(poseidon::HALF_N_FULL_ROUNDS - 1, SPONGE_WIDTH - 1) + 1; - -pub(crate) const fn col_permutation_partial_mid_sbox(round: usize) -> usize { - debug_assert!(round < poseidon::N_PARTIAL_ROUNDS); - START_PERMUTATION_PARTIAL + 2 * round -} - -pub(crate) const fn col_permutation_partial_after_sbox(round: usize) -> usize { - debug_assert!(round < poseidon::N_PARTIAL_ROUNDS); - START_PERMUTATION_PARTIAL + 2 * round + 1 -} - -const START_PERMUTATION_FULL_SECOND: usize = - col_permutation_partial_after_sbox(poseidon::N_PARTIAL_ROUNDS - 1) + 1; - -pub(crate) const fn col_permutation_full_second_mid_sbox(round: usize, i: usize) -> usize { - debug_assert!(round <= poseidon::HALF_N_FULL_ROUNDS); - debug_assert!(i < SPONGE_WIDTH); - START_PERMUTATION_FULL_SECOND + 2 * round * SPONGE_WIDTH + i -} - -pub(crate) const fn col_permutation_full_second_after_mds(round: usize, i: usize) -> usize { - debug_assert!(round <= poseidon::HALF_N_FULL_ROUNDS); - debug_assert!(i < SPONGE_WIDTH); - START_PERMUTATION_FULL_SECOND + (2 * round + 1) * SPONGE_WIDTH + i -} - -pub(crate) const fn col_permutation_input(i: usize) -> usize { - debug_assert!(i < SPONGE_WIDTH); - START_PERMUTATION_UNIT + i -} - -pub(crate) const fn col_permutation_output(i: usize) -> usize { - debug_assert!(i < SPONGE_WIDTH); - col_permutation_full_second_after_mds(poseidon::HALF_N_FULL_ROUNDS - 1, i) -} - -const END_PERMUTATION_UNIT: usize = col_permutation_output(SPONGE_WIDTH - 1); - //// MEMORY UNITS //// DECOMPOSITION UNITS +pub(crate) mod decomposition { -const START_DECOMPOSITION_UNITS: usize = END_PERMUTATION_UNIT + 1; + const START_UNITS: usize = super::permutation::END_UNIT + 1; -const NUM_DECOMPOSITION_UNITS: usize = 4; -/// The number of bits associated with a single decomposition unit. -const DECOMPOSITION_UNIT_BITS: usize = 32; -/// One column for the value being decomposed, plus one column per bit. -const DECOMPOSITION_UNIT_COLS: usize = 1 + DECOMPOSITION_UNIT_BITS; + const NUM_UNITS: usize = 4; + /// The number of bits associated with a single decomposition unit. + const UNIT_BITS: usize = 32; + /// One column for the value being decomposed, plus one column per bit. + const UNIT_COLS: usize = 1 + UNIT_BITS; -pub(crate) const fn col_decomposition_input(unit: usize) -> usize { - debug_assert!(unit < NUM_DECOMPOSITION_UNITS); - START_DECOMPOSITION_UNITS + unit * DECOMPOSITION_UNIT_COLS + pub const fn col_input(unit: usize) -> usize { + debug_assert!(unit < NUM_UNITS); + START_UNITS + unit * UNIT_COLS + } + + pub const fn col_bit(unit: usize, bit: usize) -> usize { + debug_assert!(unit < NUM_UNITS); + debug_assert!(bit < UNIT_BITS); + START_UNITS + unit * UNIT_COLS + 1 + bit + } + + pub(super) const END_UNITS: usize = START_UNITS + UNIT_COLS * NUM_UNITS; } -pub(crate) const fn col_decomposition_bit(unit: usize, bit: usize) -> usize { - debug_assert!(unit < NUM_DECOMPOSITION_UNITS); - debug_assert!(bit < DECOMPOSITION_UNIT_BITS); - START_DECOMPOSITION_UNITS + unit * DECOMPOSITION_UNIT_COLS + 1 + bit -} - -const END_DECOMPOSITION_UNITS: usize = - START_DECOMPOSITION_UNITS + DECOMPOSITION_UNIT_COLS * NUM_DECOMPOSITION_UNITS; - -pub(crate) const NUM_COLUMNS: usize = END_DECOMPOSITION_UNITS; +pub(crate) const NUM_COLUMNS: usize = decomposition::END_UNITS; diff --git a/system_zero/src/permutation_unit.rs b/system_zero/src/permutation_unit.rs index 7f12b9ce..e15474e4 100644 --- a/system_zero/src/permutation_unit.rs +++ b/system_zero/src/permutation_unit.rs @@ -8,15 +8,11 @@ use starky::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsume use starky::vars::StarkEvaluationTargets; use starky::vars::StarkEvaluationVars; -use crate::column_layout::{ - col_permutation_full_first_after_mds as col_full_1st_after_mds, - col_permutation_full_first_mid_sbox as col_full_1st_mid_sbox, - col_permutation_full_second_after_mds as col_full_2nd_after_mds, - col_permutation_full_second_mid_sbox as col_full_2nd_mid_sbox, - col_permutation_input as col_input, - col_permutation_partial_after_sbox as col_partial_after_sbox, - col_permutation_partial_mid_sbox as col_partial_mid_sbox, NUM_COLUMNS, +use crate::column_layout::permutation::{ + col_full_first_after_mds, col_full_first_mid_sbox, col_full_second_after_mds, + col_full_second_mid_sbox, col_input, col_partial_after_sbox, col_partial_mid_sbox, }; +use crate::column_layout::NUM_COLUMNS; use crate::public_input_layout::NUM_PUBLIC_INPUTS; use crate::system_zero::SystemZero; @@ -75,14 +71,14 @@ impl, const D: usize> SystemZero { for i in 0..SPONGE_WIDTH { let state_cubed = state[i].cube(); - values[col_full_1st_mid_sbox(r, i)] = state_cubed; + values[col_full_first_mid_sbox(r, i)] = state_cubed; state[i] *= state_cubed.square(); // Form state ** 7. } state = F::mds_layer(&state); for i in 0..SPONGE_WIDTH { - values[col_full_1st_after_mds(r, i)] = state[i]; + values[col_full_first_after_mds(r, i)] = state[i]; } } @@ -102,14 +98,14 @@ impl, const D: usize> SystemZero { for i in 0..SPONGE_WIDTH { let state_cubed = state[i].cube(); - values[col_full_2nd_mid_sbox(r, i)] = state_cubed; + values[col_full_second_mid_sbox(r, i)] = state_cubed; state[i] *= state_cubed.square(); // Form state ** 7. } state = F::mds_layer(&state); for i in 0..SPONGE_WIDTH { - values[col_full_2nd_after_mds(r, i)] = state[i]; + values[col_full_second_after_mds(r, i)] = state[i]; } } } @@ -136,8 +132,8 @@ impl, const D: usize> SystemZero { for i in 0..SPONGE_WIDTH { let state_cubed = state[i] * state[i].square(); yield_constr - .constraint_wrapping(state_cubed - local_values[col_full_1st_mid_sbox(r, i)]); - let state_cubed = local_values[col_full_1st_mid_sbox(r, i)]; + .constraint_wrapping(state_cubed - local_values[col_full_first_mid_sbox(r, i)]); + let state_cubed = local_values[col_full_first_mid_sbox(r, i)]; state[i] *= state_cubed.square(); // Form state ** 7. } @@ -145,8 +141,8 @@ impl, const D: usize> SystemZero { for i in 0..SPONGE_WIDTH { yield_constr - .constraint_wrapping(state[i] - local_values[col_full_1st_after_mds(r, i)]); - state[i] = local_values[col_full_1st_after_mds(r, i)]; + .constraint_wrapping(state[i] - local_values[col_full_first_after_mds(r, i)]); + state[i] = local_values[col_full_first_after_mds(r, i)]; } } @@ -168,9 +164,10 @@ impl, const D: usize> SystemZero { for i in 0..SPONGE_WIDTH { let state_cubed = state[i] * state[i].square(); - yield_constr - .constraint_wrapping(state_cubed - local_values[col_full_2nd_mid_sbox(r, i)]); - let state_cubed = local_values[col_full_2nd_mid_sbox(r, i)]; + yield_constr.constraint_wrapping( + state_cubed - local_values[col_full_second_mid_sbox(r, i)], + ); + let state_cubed = local_values[col_full_second_mid_sbox(r, i)]; state[i] *= state_cubed.square(); // Form state ** 7. } @@ -178,8 +175,8 @@ impl, const D: usize> SystemZero { for i in 0..SPONGE_WIDTH { yield_constr - .constraint_wrapping(state[i] - local_values[col_full_2nd_after_mds(r, i)]); - state[i] = local_values[col_full_2nd_after_mds(r, i)]; + .constraint_wrapping(state[i] - local_values[col_full_second_after_mds(r, i)]); + state[i] = local_values[col_full_second_after_mds(r, i)]; } } } @@ -204,9 +201,9 @@ impl, const D: usize> SystemZero { for i in 0..SPONGE_WIDTH { let state_cubed = builder.cube_extension(state[i]); let diff = - builder.sub_extension(state_cubed, local_values[col_full_1st_mid_sbox(r, i)]); + builder.sub_extension(state_cubed, local_values[col_full_first_mid_sbox(r, i)]); yield_constr.constraint_wrapping(builder, diff); - let state_cubed = local_values[col_full_1st_mid_sbox(r, i)]; + let state_cubed = local_values[col_full_first_mid_sbox(r, i)]; state[i] = builder.mul_many_extension(&[state[i], state_cubed, state_cubed]); // Form state ** 7. } @@ -215,9 +212,9 @@ impl, const D: usize> SystemZero { for i in 0..SPONGE_WIDTH { let diff = - builder.sub_extension(state[i], local_values[col_full_1st_after_mds(r, i)]); + builder.sub_extension(state[i], local_values[col_full_first_after_mds(r, i)]); yield_constr.constraint_wrapping(builder, diff); - state[i] = local_values[col_full_1st_after_mds(r, i)]; + state[i] = local_values[col_full_first_after_mds(r, i)]; } } @@ -245,10 +242,10 @@ impl, const D: usize> SystemZero { for i in 0..SPONGE_WIDTH { let state_cubed = builder.cube_extension(state[i]); - let diff = - builder.sub_extension(state_cubed, local_values[col_full_2nd_mid_sbox(r, i)]); + let diff = builder + .sub_extension(state_cubed, local_values[col_full_second_mid_sbox(r, i)]); yield_constr.constraint_wrapping(builder, diff); - let state_cubed = local_values[col_full_2nd_mid_sbox(r, i)]; + let state_cubed = local_values[col_full_second_mid_sbox(r, i)]; state[i] = builder.mul_many_extension(&[state[i], state_cubed, state_cubed]); // Form state ** 7. } @@ -257,9 +254,9 @@ impl, const D: usize> SystemZero { for i in 0..SPONGE_WIDTH { let diff = - builder.sub_extension(state[i], local_values[col_full_2nd_after_mds(r, i)]); + builder.sub_extension(state[i], local_values[col_full_second_after_mds(r, i)]); yield_constr.constraint_wrapping(builder, diff); - state[i] = local_values[col_full_2nd_after_mds(r, i)]; + state[i] = local_values[col_full_second_after_mds(r, i)]; } } } @@ -275,9 +272,8 @@ mod tests { use starky::constraint_consumer::ConstraintConsumer; use starky::vars::StarkEvaluationVars; - use crate::column_layout::{ - col_permutation_input as col_input, col_permutation_output as col_output, NUM_COLUMNS, - }; + use crate::column_layout::permutation::{col_input, col_output}; + use crate::column_layout::NUM_COLUMNS; use crate::permutation_unit::SPONGE_WIDTH; use crate::public_input_layout::NUM_PUBLIC_INPUTS; use crate::system_zero::SystemZero;