Split system_zero::column_layout into submodules (#475)

This commit is contained in:
Jakub Nabaglo 2022-02-07 14:29:31 -08:00 committed by GitHub
parent 8a07d7af41
commit efb1365021
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 104 additions and 107 deletions

View File

@ -1,6 +1,3 @@
use plonky2::hash::hashing::SPONGE_WIDTH;
use plonky2::hash::poseidon;
//// CORE REGISTERS //// CORE REGISTERS
/// A cycle counter. Starts at 0; increments by 1. /// A cycle counter. Starts at 0; increments by 1.
@ -21,87 +18,91 @@ pub(crate) const COL_FRAME_PTR: usize = COL_INSTRUCTION_PTR + 1;
pub(crate) const COL_STACK_PTR: usize = COL_FRAME_PTR + 1; pub(crate) const COL_STACK_PTR: usize = COL_FRAME_PTR + 1;
//// PERMUTATION UNIT //// PERMUTATION UNIT
pub(crate) mod permutation {
use plonky2::hash::hashing::SPONGE_WIDTH;
use plonky2::hash::poseidon;
const START_PERMUTATION_UNIT: usize = COL_STACK_PTR + 1; const START_UNIT: usize = super::COL_STACK_PTR + 1;
const START_PERMUTATION_FULL_FIRST: usize = START_PERMUTATION_UNIT + SPONGE_WIDTH; const START_FULL_FIRST: usize = START_UNIT + SPONGE_WIDTH;
pub(crate) const fn col_permutation_full_first_mid_sbox(round: usize, i: usize) -> usize { pub const fn col_full_first_mid_sbox(round: usize, i: usize) -> usize {
debug_assert!(round < poseidon::HALF_N_FULL_ROUNDS); debug_assert!(round < poseidon::HALF_N_FULL_ROUNDS);
debug_assert!(i < SPONGE_WIDTH); debug_assert!(i < SPONGE_WIDTH);
START_PERMUTATION_FULL_FIRST + 2 * round * SPONGE_WIDTH + i START_FULL_FIRST + 2 * round * SPONGE_WIDTH + i
} }
pub(crate) const fn col_permutation_full_first_after_mds(round: usize, i: usize) -> usize { pub const fn col_full_first_after_mds(round: usize, i: usize) -> usize {
debug_assert!(round < poseidon::HALF_N_FULL_ROUNDS); debug_assert!(round < poseidon::HALF_N_FULL_ROUNDS);
debug_assert!(i < SPONGE_WIDTH); debug_assert!(i < SPONGE_WIDTH);
START_PERMUTATION_FULL_FIRST + (2 * round + 1) * SPONGE_WIDTH + i START_FULL_FIRST + (2 * round + 1) * SPONGE_WIDTH + i
} }
const START_PERMUTATION_PARTIAL: usize = const START_PARTIAL: usize =
col_permutation_full_first_after_mds(poseidon::HALF_N_FULL_ROUNDS - 1, SPONGE_WIDTH - 1) + 1; col_full_first_after_mds(poseidon::HALF_N_FULL_ROUNDS - 1, SPONGE_WIDTH - 1) + 1;
pub(crate) const fn col_permutation_partial_mid_sbox(round: usize) -> usize { pub const fn col_partial_mid_sbox(round: usize) -> usize {
debug_assert!(round < poseidon::N_PARTIAL_ROUNDS); debug_assert!(round < poseidon::N_PARTIAL_ROUNDS);
START_PERMUTATION_PARTIAL + 2 * round START_PARTIAL + 2 * round
} }
pub(crate) const fn col_permutation_partial_after_sbox(round: usize) -> usize { pub const fn col_partial_after_sbox(round: usize) -> usize {
debug_assert!(round < poseidon::N_PARTIAL_ROUNDS); debug_assert!(round < poseidon::N_PARTIAL_ROUNDS);
START_PERMUTATION_PARTIAL + 2 * round + 1 START_PARTIAL + 2 * round + 1
} }
const START_PERMUTATION_FULL_SECOND: usize = const START_FULL_SECOND: usize = col_partial_after_sbox(poseidon::N_PARTIAL_ROUNDS - 1) + 1;
col_permutation_partial_after_sbox(poseidon::N_PARTIAL_ROUNDS - 1) + 1;
pub(crate) const fn col_permutation_full_second_mid_sbox(round: usize, i: usize) -> usize { pub const fn col_full_second_mid_sbox(round: usize, i: usize) -> usize {
debug_assert!(round <= poseidon::HALF_N_FULL_ROUNDS); debug_assert!(round <= poseidon::HALF_N_FULL_ROUNDS);
debug_assert!(i < SPONGE_WIDTH); debug_assert!(i < SPONGE_WIDTH);
START_PERMUTATION_FULL_SECOND + 2 * round * SPONGE_WIDTH + i START_FULL_SECOND + 2 * round * SPONGE_WIDTH + i
} }
pub(crate) const fn col_permutation_full_second_after_mds(round: usize, i: usize) -> usize { pub const fn col_full_second_after_mds(round: usize, i: usize) -> usize {
debug_assert!(round <= poseidon::HALF_N_FULL_ROUNDS); debug_assert!(round <= poseidon::HALF_N_FULL_ROUNDS);
debug_assert!(i < SPONGE_WIDTH); debug_assert!(i < SPONGE_WIDTH);
START_PERMUTATION_FULL_SECOND + (2 * round + 1) * SPONGE_WIDTH + i START_FULL_SECOND + (2 * round + 1) * SPONGE_WIDTH + i
} }
pub(crate) const fn col_permutation_input(i: usize) -> usize { pub const fn col_input(i: usize) -> usize {
debug_assert!(i < SPONGE_WIDTH); debug_assert!(i < SPONGE_WIDTH);
START_PERMUTATION_UNIT + i START_UNIT + i
} }
pub(crate) const fn col_permutation_output(i: usize) -> usize { pub const fn col_output(i: usize) -> usize {
debug_assert!(i < SPONGE_WIDTH); debug_assert!(i < SPONGE_WIDTH);
col_permutation_full_second_after_mds(poseidon::HALF_N_FULL_ROUNDS - 1, i) col_full_second_after_mds(poseidon::HALF_N_FULL_ROUNDS - 1, i)
} }
const END_PERMUTATION_UNIT: usize = col_permutation_output(SPONGE_WIDTH - 1); pub(super) const END_UNIT: usize = col_output(SPONGE_WIDTH - 1);
}
//// MEMORY UNITS //// MEMORY UNITS
//// DECOMPOSITION UNITS //// DECOMPOSITION UNITS
pub(crate) mod decomposition {
const START_DECOMPOSITION_UNITS: usize = END_PERMUTATION_UNIT + 1; const START_UNITS: usize = super::permutation::END_UNIT + 1;
const NUM_DECOMPOSITION_UNITS: usize = 4; const NUM_UNITS: usize = 4;
/// The number of bits associated with a single decomposition unit. /// The number of bits associated with a single decomposition unit.
const DECOMPOSITION_UNIT_BITS: usize = 32; const UNIT_BITS: usize = 32;
/// One column for the value being decomposed, plus one column per bit. /// One column for the value being decomposed, plus one column per bit.
const DECOMPOSITION_UNIT_COLS: usize = 1 + DECOMPOSITION_UNIT_BITS; const UNIT_COLS: usize = 1 + UNIT_BITS;
pub(crate) const fn col_decomposition_input(unit: usize) -> usize { pub const fn col_input(unit: usize) -> usize {
debug_assert!(unit < NUM_DECOMPOSITION_UNITS); debug_assert!(unit < NUM_UNITS);
START_DECOMPOSITION_UNITS + unit * DECOMPOSITION_UNIT_COLS START_UNITS + unit * UNIT_COLS
}
pub const fn col_bit(unit: usize, bit: usize) -> usize {
debug_assert!(unit < NUM_UNITS);
debug_assert!(bit < UNIT_BITS);
START_UNITS + unit * UNIT_COLS + 1 + bit
}
pub(super) const END_UNITS: usize = START_UNITS + UNIT_COLS * NUM_UNITS;
} }
pub(crate) const fn col_decomposition_bit(unit: usize, bit: usize) -> usize { pub(crate) const NUM_COLUMNS: usize = decomposition::END_UNITS;
debug_assert!(unit < NUM_DECOMPOSITION_UNITS);
debug_assert!(bit < DECOMPOSITION_UNIT_BITS);
START_DECOMPOSITION_UNITS + unit * DECOMPOSITION_UNIT_COLS + 1 + bit
}
const END_DECOMPOSITION_UNITS: usize =
START_DECOMPOSITION_UNITS + DECOMPOSITION_UNIT_COLS * NUM_DECOMPOSITION_UNITS;
pub(crate) const NUM_COLUMNS: usize = END_DECOMPOSITION_UNITS;

View File

@ -8,15 +8,11 @@ use starky::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsume
use starky::vars::StarkEvaluationTargets; use starky::vars::StarkEvaluationTargets;
use starky::vars::StarkEvaluationVars; use starky::vars::StarkEvaluationVars;
use crate::column_layout::{ use crate::column_layout::permutation::{
col_permutation_full_first_after_mds as col_full_1st_after_mds, col_full_first_after_mds, col_full_first_mid_sbox, col_full_second_after_mds,
col_permutation_full_first_mid_sbox as col_full_1st_mid_sbox, col_full_second_mid_sbox, col_input, col_partial_after_sbox, col_partial_mid_sbox,
col_permutation_full_second_after_mds as col_full_2nd_after_mds,
col_permutation_full_second_mid_sbox as col_full_2nd_mid_sbox,
col_permutation_input as col_input,
col_permutation_partial_after_sbox as col_partial_after_sbox,
col_permutation_partial_mid_sbox as col_partial_mid_sbox, NUM_COLUMNS,
}; };
use crate::column_layout::NUM_COLUMNS;
use crate::public_input_layout::NUM_PUBLIC_INPUTS; use crate::public_input_layout::NUM_PUBLIC_INPUTS;
use crate::system_zero::SystemZero; use crate::system_zero::SystemZero;
@ -75,14 +71,14 @@ impl<F: RichField + Extendable<D>, const D: usize> SystemZero<F, D> {
for i in 0..SPONGE_WIDTH { for i in 0..SPONGE_WIDTH {
let state_cubed = state[i].cube(); let state_cubed = state[i].cube();
values[col_full_1st_mid_sbox(r, i)] = state_cubed; values[col_full_first_mid_sbox(r, i)] = state_cubed;
state[i] *= state_cubed.square(); // Form state ** 7. state[i] *= state_cubed.square(); // Form state ** 7.
} }
state = F::mds_layer(&state); state = F::mds_layer(&state);
for i in 0..SPONGE_WIDTH { for i in 0..SPONGE_WIDTH {
values[col_full_1st_after_mds(r, i)] = state[i]; values[col_full_first_after_mds(r, i)] = state[i];
} }
} }
@ -102,14 +98,14 @@ impl<F: RichField + Extendable<D>, const D: usize> SystemZero<F, D> {
for i in 0..SPONGE_WIDTH { for i in 0..SPONGE_WIDTH {
let state_cubed = state[i].cube(); let state_cubed = state[i].cube();
values[col_full_2nd_mid_sbox(r, i)] = state_cubed; values[col_full_second_mid_sbox(r, i)] = state_cubed;
state[i] *= state_cubed.square(); // Form state ** 7. state[i] *= state_cubed.square(); // Form state ** 7.
} }
state = F::mds_layer(&state); state = F::mds_layer(&state);
for i in 0..SPONGE_WIDTH { for i in 0..SPONGE_WIDTH {
values[col_full_2nd_after_mds(r, i)] = state[i]; values[col_full_second_after_mds(r, i)] = state[i];
} }
} }
} }
@ -136,8 +132,8 @@ impl<F: RichField + Extendable<D>, const D: usize> SystemZero<F, D> {
for i in 0..SPONGE_WIDTH { for i in 0..SPONGE_WIDTH {
let state_cubed = state[i] * state[i].square(); let state_cubed = state[i] * state[i].square();
yield_constr yield_constr
.constraint_wrapping(state_cubed - local_values[col_full_1st_mid_sbox(r, i)]); .constraint_wrapping(state_cubed - local_values[col_full_first_mid_sbox(r, i)]);
let state_cubed = local_values[col_full_1st_mid_sbox(r, i)]; let state_cubed = local_values[col_full_first_mid_sbox(r, i)];
state[i] *= state_cubed.square(); // Form state ** 7. state[i] *= state_cubed.square(); // Form state ** 7.
} }
@ -145,8 +141,8 @@ impl<F: RichField + Extendable<D>, const D: usize> SystemZero<F, D> {
for i in 0..SPONGE_WIDTH { for i in 0..SPONGE_WIDTH {
yield_constr yield_constr
.constraint_wrapping(state[i] - local_values[col_full_1st_after_mds(r, i)]); .constraint_wrapping(state[i] - local_values[col_full_first_after_mds(r, i)]);
state[i] = local_values[col_full_1st_after_mds(r, i)]; state[i] = local_values[col_full_first_after_mds(r, i)];
} }
} }
@ -168,9 +164,10 @@ impl<F: RichField + Extendable<D>, const D: usize> SystemZero<F, D> {
for i in 0..SPONGE_WIDTH { for i in 0..SPONGE_WIDTH {
let state_cubed = state[i] * state[i].square(); let state_cubed = state[i] * state[i].square();
yield_constr yield_constr.constraint_wrapping(
.constraint_wrapping(state_cubed - local_values[col_full_2nd_mid_sbox(r, i)]); state_cubed - local_values[col_full_second_mid_sbox(r, i)],
let state_cubed = local_values[col_full_2nd_mid_sbox(r, i)]; );
let state_cubed = local_values[col_full_second_mid_sbox(r, i)];
state[i] *= state_cubed.square(); // Form state ** 7. state[i] *= state_cubed.square(); // Form state ** 7.
} }
@ -178,8 +175,8 @@ impl<F: RichField + Extendable<D>, const D: usize> SystemZero<F, D> {
for i in 0..SPONGE_WIDTH { for i in 0..SPONGE_WIDTH {
yield_constr yield_constr
.constraint_wrapping(state[i] - local_values[col_full_2nd_after_mds(r, i)]); .constraint_wrapping(state[i] - local_values[col_full_second_after_mds(r, i)]);
state[i] = local_values[col_full_2nd_after_mds(r, i)]; state[i] = local_values[col_full_second_after_mds(r, i)];
} }
} }
} }
@ -204,9 +201,9 @@ impl<F: RichField + Extendable<D>, const D: usize> SystemZero<F, D> {
for i in 0..SPONGE_WIDTH { for i in 0..SPONGE_WIDTH {
let state_cubed = builder.cube_extension(state[i]); let state_cubed = builder.cube_extension(state[i]);
let diff = let diff =
builder.sub_extension(state_cubed, local_values[col_full_1st_mid_sbox(r, i)]); builder.sub_extension(state_cubed, local_values[col_full_first_mid_sbox(r, i)]);
yield_constr.constraint_wrapping(builder, diff); yield_constr.constraint_wrapping(builder, diff);
let state_cubed = local_values[col_full_1st_mid_sbox(r, i)]; let state_cubed = local_values[col_full_first_mid_sbox(r, i)];
state[i] = builder.mul_many_extension(&[state[i], state_cubed, state_cubed]); state[i] = builder.mul_many_extension(&[state[i], state_cubed, state_cubed]);
// Form state ** 7. // Form state ** 7.
} }
@ -215,9 +212,9 @@ impl<F: RichField + Extendable<D>, const D: usize> SystemZero<F, D> {
for i in 0..SPONGE_WIDTH { for i in 0..SPONGE_WIDTH {
let diff = let diff =
builder.sub_extension(state[i], local_values[col_full_1st_after_mds(r, i)]); builder.sub_extension(state[i], local_values[col_full_first_after_mds(r, i)]);
yield_constr.constraint_wrapping(builder, diff); yield_constr.constraint_wrapping(builder, diff);
state[i] = local_values[col_full_1st_after_mds(r, i)]; state[i] = local_values[col_full_first_after_mds(r, i)];
} }
} }
@ -245,10 +242,10 @@ impl<F: RichField + Extendable<D>, const D: usize> SystemZero<F, D> {
for i in 0..SPONGE_WIDTH { for i in 0..SPONGE_WIDTH {
let state_cubed = builder.cube_extension(state[i]); let state_cubed = builder.cube_extension(state[i]);
let diff = let diff = builder
builder.sub_extension(state_cubed, local_values[col_full_2nd_mid_sbox(r, i)]); .sub_extension(state_cubed, local_values[col_full_second_mid_sbox(r, i)]);
yield_constr.constraint_wrapping(builder, diff); yield_constr.constraint_wrapping(builder, diff);
let state_cubed = local_values[col_full_2nd_mid_sbox(r, i)]; let state_cubed = local_values[col_full_second_mid_sbox(r, i)];
state[i] = builder.mul_many_extension(&[state[i], state_cubed, state_cubed]); state[i] = builder.mul_many_extension(&[state[i], state_cubed, state_cubed]);
// Form state ** 7. // Form state ** 7.
} }
@ -257,9 +254,9 @@ impl<F: RichField + Extendable<D>, const D: usize> SystemZero<F, D> {
for i in 0..SPONGE_WIDTH { for i in 0..SPONGE_WIDTH {
let diff = let diff =
builder.sub_extension(state[i], local_values[col_full_2nd_after_mds(r, i)]); builder.sub_extension(state[i], local_values[col_full_second_after_mds(r, i)]);
yield_constr.constraint_wrapping(builder, diff); yield_constr.constraint_wrapping(builder, diff);
state[i] = local_values[col_full_2nd_after_mds(r, i)]; state[i] = local_values[col_full_second_after_mds(r, i)];
} }
} }
} }
@ -275,9 +272,8 @@ mod tests {
use starky::constraint_consumer::ConstraintConsumer; use starky::constraint_consumer::ConstraintConsumer;
use starky::vars::StarkEvaluationVars; use starky::vars::StarkEvaluationVars;
use crate::column_layout::{ use crate::column_layout::permutation::{col_input, col_output};
col_permutation_input as col_input, col_permutation_output as col_output, NUM_COLUMNS, use crate::column_layout::NUM_COLUMNS;
};
use crate::permutation_unit::SPONGE_WIDTH; use crate::permutation_unit::SPONGE_WIDTH;
use crate::public_input_layout::NUM_PUBLIC_INPUTS; use crate::public_input_layout::NUM_PUBLIC_INPUTS;
use crate::system_zero::SystemZero; use crate::system_zero::SystemZero;