Split system_zero::column_layout into submodules (#475)

This commit is contained in:
Jakub Nabaglo 2022-02-07 14:29:31 -08:00 committed by GitHub
parent 8a07d7af41
commit efb1365021
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 104 additions and 107 deletions

View File

@ -1,6 +1,3 @@
use plonky2::hash::hashing::SPONGE_WIDTH;
use plonky2::hash::poseidon;
//// CORE REGISTERS
/// A cycle counter. Starts at 0; increments by 1.
@ -21,87 +18,91 @@ pub(crate) const COL_FRAME_PTR: usize = COL_INSTRUCTION_PTR + 1;
pub(crate) const COL_STACK_PTR: usize = COL_FRAME_PTR + 1;
//// PERMUTATION UNIT
pub(crate) mod permutation {
use plonky2::hash::hashing::SPONGE_WIDTH;
use plonky2::hash::poseidon;
const START_PERMUTATION_UNIT: usize = COL_STACK_PTR + 1;
const START_UNIT: usize = super::COL_STACK_PTR + 1;
const START_PERMUTATION_FULL_FIRST: usize = START_PERMUTATION_UNIT + SPONGE_WIDTH;
const START_FULL_FIRST: usize = START_UNIT + SPONGE_WIDTH;
pub(crate) const fn col_permutation_full_first_mid_sbox(round: usize, i: usize) -> usize {
debug_assert!(round < poseidon::HALF_N_FULL_ROUNDS);
debug_assert!(i < SPONGE_WIDTH);
START_PERMUTATION_FULL_FIRST + 2 * round * SPONGE_WIDTH + i
pub const fn col_full_first_mid_sbox(round: usize, i: usize) -> usize {
debug_assert!(round < poseidon::HALF_N_FULL_ROUNDS);
debug_assert!(i < SPONGE_WIDTH);
START_FULL_FIRST + 2 * round * SPONGE_WIDTH + i
}
pub const fn col_full_first_after_mds(round: usize, i: usize) -> usize {
debug_assert!(round < poseidon::HALF_N_FULL_ROUNDS);
debug_assert!(i < SPONGE_WIDTH);
START_FULL_FIRST + (2 * round + 1) * SPONGE_WIDTH + i
}
const START_PARTIAL: usize =
col_full_first_after_mds(poseidon::HALF_N_FULL_ROUNDS - 1, SPONGE_WIDTH - 1) + 1;
pub const fn col_partial_mid_sbox(round: usize) -> usize {
debug_assert!(round < poseidon::N_PARTIAL_ROUNDS);
START_PARTIAL + 2 * round
}
pub const fn col_partial_after_sbox(round: usize) -> usize {
debug_assert!(round < poseidon::N_PARTIAL_ROUNDS);
START_PARTIAL + 2 * round + 1
}
const START_FULL_SECOND: usize = col_partial_after_sbox(poseidon::N_PARTIAL_ROUNDS - 1) + 1;
pub const fn col_full_second_mid_sbox(round: usize, i: usize) -> usize {
debug_assert!(round <= poseidon::HALF_N_FULL_ROUNDS);
debug_assert!(i < SPONGE_WIDTH);
START_FULL_SECOND + 2 * round * SPONGE_WIDTH + i
}
pub const fn col_full_second_after_mds(round: usize, i: usize) -> usize {
debug_assert!(round <= poseidon::HALF_N_FULL_ROUNDS);
debug_assert!(i < SPONGE_WIDTH);
START_FULL_SECOND + (2 * round + 1) * SPONGE_WIDTH + i
}
pub const fn col_input(i: usize) -> usize {
debug_assert!(i < SPONGE_WIDTH);
START_UNIT + i
}
pub const fn col_output(i: usize) -> usize {
debug_assert!(i < SPONGE_WIDTH);
col_full_second_after_mds(poseidon::HALF_N_FULL_ROUNDS - 1, i)
}
pub(super) const END_UNIT: usize = col_output(SPONGE_WIDTH - 1);
}
pub(crate) const fn col_permutation_full_first_after_mds(round: usize, i: usize) -> usize {
debug_assert!(round < poseidon::HALF_N_FULL_ROUNDS);
debug_assert!(i < SPONGE_WIDTH);
START_PERMUTATION_FULL_FIRST + (2 * round + 1) * SPONGE_WIDTH + i
}
const START_PERMUTATION_PARTIAL: usize =
col_permutation_full_first_after_mds(poseidon::HALF_N_FULL_ROUNDS - 1, SPONGE_WIDTH - 1) + 1;
pub(crate) const fn col_permutation_partial_mid_sbox(round: usize) -> usize {
debug_assert!(round < poseidon::N_PARTIAL_ROUNDS);
START_PERMUTATION_PARTIAL + 2 * round
}
pub(crate) const fn col_permutation_partial_after_sbox(round: usize) -> usize {
debug_assert!(round < poseidon::N_PARTIAL_ROUNDS);
START_PERMUTATION_PARTIAL + 2 * round + 1
}
const START_PERMUTATION_FULL_SECOND: usize =
col_permutation_partial_after_sbox(poseidon::N_PARTIAL_ROUNDS - 1) + 1;
pub(crate) const fn col_permutation_full_second_mid_sbox(round: usize, i: usize) -> usize {
debug_assert!(round <= poseidon::HALF_N_FULL_ROUNDS);
debug_assert!(i < SPONGE_WIDTH);
START_PERMUTATION_FULL_SECOND + 2 * round * SPONGE_WIDTH + i
}
pub(crate) const fn col_permutation_full_second_after_mds(round: usize, i: usize) -> usize {
debug_assert!(round <= poseidon::HALF_N_FULL_ROUNDS);
debug_assert!(i < SPONGE_WIDTH);
START_PERMUTATION_FULL_SECOND + (2 * round + 1) * SPONGE_WIDTH + i
}
pub(crate) const fn col_permutation_input(i: usize) -> usize {
debug_assert!(i < SPONGE_WIDTH);
START_PERMUTATION_UNIT + i
}
pub(crate) const fn col_permutation_output(i: usize) -> usize {
debug_assert!(i < SPONGE_WIDTH);
col_permutation_full_second_after_mds(poseidon::HALF_N_FULL_ROUNDS - 1, i)
}
const END_PERMUTATION_UNIT: usize = col_permutation_output(SPONGE_WIDTH - 1);
//// MEMORY UNITS
//// DECOMPOSITION UNITS
pub(crate) mod decomposition {
const START_DECOMPOSITION_UNITS: usize = END_PERMUTATION_UNIT + 1;
const START_UNITS: usize = super::permutation::END_UNIT + 1;
const NUM_DECOMPOSITION_UNITS: usize = 4;
/// The number of bits associated with a single decomposition unit.
const DECOMPOSITION_UNIT_BITS: usize = 32;
/// One column for the value being decomposed, plus one column per bit.
const DECOMPOSITION_UNIT_COLS: usize = 1 + DECOMPOSITION_UNIT_BITS;
const NUM_UNITS: usize = 4;
/// The number of bits associated with a single decomposition unit.
const UNIT_BITS: usize = 32;
/// One column for the value being decomposed, plus one column per bit.
const UNIT_COLS: usize = 1 + UNIT_BITS;
pub(crate) const fn col_decomposition_input(unit: usize) -> usize {
debug_assert!(unit < NUM_DECOMPOSITION_UNITS);
START_DECOMPOSITION_UNITS + unit * DECOMPOSITION_UNIT_COLS
pub const fn col_input(unit: usize) -> usize {
debug_assert!(unit < NUM_UNITS);
START_UNITS + unit * UNIT_COLS
}
pub const fn col_bit(unit: usize, bit: usize) -> usize {
debug_assert!(unit < NUM_UNITS);
debug_assert!(bit < UNIT_BITS);
START_UNITS + unit * UNIT_COLS + 1 + bit
}
pub(super) const END_UNITS: usize = START_UNITS + UNIT_COLS * NUM_UNITS;
}
pub(crate) const fn col_decomposition_bit(unit: usize, bit: usize) -> usize {
debug_assert!(unit < NUM_DECOMPOSITION_UNITS);
debug_assert!(bit < DECOMPOSITION_UNIT_BITS);
START_DECOMPOSITION_UNITS + unit * DECOMPOSITION_UNIT_COLS + 1 + bit
}
const END_DECOMPOSITION_UNITS: usize =
START_DECOMPOSITION_UNITS + DECOMPOSITION_UNIT_COLS * NUM_DECOMPOSITION_UNITS;
pub(crate) const NUM_COLUMNS: usize = END_DECOMPOSITION_UNITS;
pub(crate) const NUM_COLUMNS: usize = decomposition::END_UNITS;

View File

@ -8,15 +8,11 @@ use starky::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsume
use starky::vars::StarkEvaluationTargets;
use starky::vars::StarkEvaluationVars;
use crate::column_layout::{
col_permutation_full_first_after_mds as col_full_1st_after_mds,
col_permutation_full_first_mid_sbox as col_full_1st_mid_sbox,
col_permutation_full_second_after_mds as col_full_2nd_after_mds,
col_permutation_full_second_mid_sbox as col_full_2nd_mid_sbox,
col_permutation_input as col_input,
col_permutation_partial_after_sbox as col_partial_after_sbox,
col_permutation_partial_mid_sbox as col_partial_mid_sbox, NUM_COLUMNS,
use crate::column_layout::permutation::{
col_full_first_after_mds, col_full_first_mid_sbox, col_full_second_after_mds,
col_full_second_mid_sbox, col_input, col_partial_after_sbox, col_partial_mid_sbox,
};
use crate::column_layout::NUM_COLUMNS;
use crate::public_input_layout::NUM_PUBLIC_INPUTS;
use crate::system_zero::SystemZero;
@ -75,14 +71,14 @@ impl<F: RichField + Extendable<D>, const D: usize> SystemZero<F, D> {
for i in 0..SPONGE_WIDTH {
let state_cubed = state[i].cube();
values[col_full_1st_mid_sbox(r, i)] = state_cubed;
values[col_full_first_mid_sbox(r, i)] = state_cubed;
state[i] *= state_cubed.square(); // Form state ** 7.
}
state = F::mds_layer(&state);
for i in 0..SPONGE_WIDTH {
values[col_full_1st_after_mds(r, i)] = state[i];
values[col_full_first_after_mds(r, i)] = state[i];
}
}
@ -102,14 +98,14 @@ impl<F: RichField + Extendable<D>, const D: usize> SystemZero<F, D> {
for i in 0..SPONGE_WIDTH {
let state_cubed = state[i].cube();
values[col_full_2nd_mid_sbox(r, i)] = state_cubed;
values[col_full_second_mid_sbox(r, i)] = state_cubed;
state[i] *= state_cubed.square(); // Form state ** 7.
}
state = F::mds_layer(&state);
for i in 0..SPONGE_WIDTH {
values[col_full_2nd_after_mds(r, i)] = state[i];
values[col_full_second_after_mds(r, i)] = state[i];
}
}
}
@ -136,8 +132,8 @@ impl<F: RichField + Extendable<D>, const D: usize> SystemZero<F, D> {
for i in 0..SPONGE_WIDTH {
let state_cubed = state[i] * state[i].square();
yield_constr
.constraint_wrapping(state_cubed - local_values[col_full_1st_mid_sbox(r, i)]);
let state_cubed = local_values[col_full_1st_mid_sbox(r, i)];
.constraint_wrapping(state_cubed - local_values[col_full_first_mid_sbox(r, i)]);
let state_cubed = local_values[col_full_first_mid_sbox(r, i)];
state[i] *= state_cubed.square(); // Form state ** 7.
}
@ -145,8 +141,8 @@ impl<F: RichField + Extendable<D>, const D: usize> SystemZero<F, D> {
for i in 0..SPONGE_WIDTH {
yield_constr
.constraint_wrapping(state[i] - local_values[col_full_1st_after_mds(r, i)]);
state[i] = local_values[col_full_1st_after_mds(r, i)];
.constraint_wrapping(state[i] - local_values[col_full_first_after_mds(r, i)]);
state[i] = local_values[col_full_first_after_mds(r, i)];
}
}
@ -168,9 +164,10 @@ impl<F: RichField + Extendable<D>, const D: usize> SystemZero<F, D> {
for i in 0..SPONGE_WIDTH {
let state_cubed = state[i] * state[i].square();
yield_constr
.constraint_wrapping(state_cubed - local_values[col_full_2nd_mid_sbox(r, i)]);
let state_cubed = local_values[col_full_2nd_mid_sbox(r, i)];
yield_constr.constraint_wrapping(
state_cubed - local_values[col_full_second_mid_sbox(r, i)],
);
let state_cubed = local_values[col_full_second_mid_sbox(r, i)];
state[i] *= state_cubed.square(); // Form state ** 7.
}
@ -178,8 +175,8 @@ impl<F: RichField + Extendable<D>, const D: usize> SystemZero<F, D> {
for i in 0..SPONGE_WIDTH {
yield_constr
.constraint_wrapping(state[i] - local_values[col_full_2nd_after_mds(r, i)]);
state[i] = local_values[col_full_2nd_after_mds(r, i)];
.constraint_wrapping(state[i] - local_values[col_full_second_after_mds(r, i)]);
state[i] = local_values[col_full_second_after_mds(r, i)];
}
}
}
@ -204,9 +201,9 @@ impl<F: RichField + Extendable<D>, const D: usize> SystemZero<F, D> {
for i in 0..SPONGE_WIDTH {
let state_cubed = builder.cube_extension(state[i]);
let diff =
builder.sub_extension(state_cubed, local_values[col_full_1st_mid_sbox(r, i)]);
builder.sub_extension(state_cubed, local_values[col_full_first_mid_sbox(r, i)]);
yield_constr.constraint_wrapping(builder, diff);
let state_cubed = local_values[col_full_1st_mid_sbox(r, i)];
let state_cubed = local_values[col_full_first_mid_sbox(r, i)];
state[i] = builder.mul_many_extension(&[state[i], state_cubed, state_cubed]);
// Form state ** 7.
}
@ -215,9 +212,9 @@ impl<F: RichField + Extendable<D>, const D: usize> SystemZero<F, D> {
for i in 0..SPONGE_WIDTH {
let diff =
builder.sub_extension(state[i], local_values[col_full_1st_after_mds(r, i)]);
builder.sub_extension(state[i], local_values[col_full_first_after_mds(r, i)]);
yield_constr.constraint_wrapping(builder, diff);
state[i] = local_values[col_full_1st_after_mds(r, i)];
state[i] = local_values[col_full_first_after_mds(r, i)];
}
}
@ -245,10 +242,10 @@ impl<F: RichField + Extendable<D>, const D: usize> SystemZero<F, D> {
for i in 0..SPONGE_WIDTH {
let state_cubed = builder.cube_extension(state[i]);
let diff =
builder.sub_extension(state_cubed, local_values[col_full_2nd_mid_sbox(r, i)]);
let diff = builder
.sub_extension(state_cubed, local_values[col_full_second_mid_sbox(r, i)]);
yield_constr.constraint_wrapping(builder, diff);
let state_cubed = local_values[col_full_2nd_mid_sbox(r, i)];
let state_cubed = local_values[col_full_second_mid_sbox(r, i)];
state[i] = builder.mul_many_extension(&[state[i], state_cubed, state_cubed]);
// Form state ** 7.
}
@ -257,9 +254,9 @@ impl<F: RichField + Extendable<D>, const D: usize> SystemZero<F, D> {
for i in 0..SPONGE_WIDTH {
let diff =
builder.sub_extension(state[i], local_values[col_full_2nd_after_mds(r, i)]);
builder.sub_extension(state[i], local_values[col_full_second_after_mds(r, i)]);
yield_constr.constraint_wrapping(builder, diff);
state[i] = local_values[col_full_2nd_after_mds(r, i)];
state[i] = local_values[col_full_second_after_mds(r, i)];
}
}
}
@ -275,9 +272,8 @@ mod tests {
use starky::constraint_consumer::ConstraintConsumer;
use starky::vars::StarkEvaluationVars;
use crate::column_layout::{
col_permutation_input as col_input, col_permutation_output as col_output, NUM_COLUMNS,
};
use crate::column_layout::permutation::{col_input, col_output};
use crate::column_layout::NUM_COLUMNS;
use crate::permutation_unit::SPONGE_WIDTH;
use crate::public_input_layout::NUM_PUBLIC_INPUTS;
use crate::system_zero::SystemZero;