Move some methods outside impl System (#484)

I didn't really have a good reason for putting them there; it seems more idiomatic to make them free functions, since they don't need `self`/`Self`. (A minimal sketch of the pattern follows the change summary below.)
This commit is contained in:
Daniel Lubarov 2022-02-14 13:47:33 -08:00 committed by GitHub
parent 96c9a2385b
commit 8d699edf21
8 changed files with 282 additions and 286 deletions
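For illustration, here is a minimal sketch of the refactor pattern this commit applies: an associated function that never touches `self` or `Self` is hoisted out of the `impl` block into a free function that carries only the trait bound it actually uses. The `Number` trait and `Registers` struct below are hypothetical stand-ins for this sketch, not types from this repository.

```rust
// Hypothetical stand-ins for illustration; `Number` and `Registers` are not
// types from this codebase.
pub trait Number: Copy {
    fn zero() -> Self;
}

impl Number for u64 {
    fn zero() -> Self {
        0
    }
}

pub struct Registers;

impl Registers {
    // Before: an associated function sitting in the impl block even though it
    // never uses `self` or `Self`.
    pub fn clear_in_impl<N: Number>(values: &mut [N; 4]) {
        for v in values.iter_mut() {
            *v = N::zero();
        }
    }
}

// After: a plain free function carrying only the bound it actually needs.
pub fn clear<N: Number>(values: &mut [N; 4]) {
    for v in values.iter_mut() {
        *v = N::zero();
    }
}

fn main() {
    let mut values = [1u64, 2, 3, 4];
    Registers::clear_in_impl(&mut values); // old call shape
    clear(&mut values); // new call shape
    assert_eq!(values, [0u64; 4]);
}
```

A related effect visible in the diff: once the generators leave the `impl<F: RichField + Extendable<D>, const D: usize>` block, they can ask for only the bound they actually use (`PrimeField64`, `Field`, or `Poseidon`) instead of inheriting the blanket `RichField + Extendable<D>` bound.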

View File

@@ -1,5 +1,5 @@
use plonky2::field::extension_field::Extendable;
use plonky2::field::field_types::Field;
use plonky2::field::field_types::{Field, PrimeField64};
use plonky2::field::packed_field::PackedField;
use plonky2::hash::hash_types::RichField;
use plonky2::iop::ext_target::ExtensionTarget;
@@ -10,7 +10,7 @@ use starky::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer};
use crate::registers::arithmetic::*;
use crate::registers::NUM_COLUMNS;
pub(crate) fn generate_addition<F: RichField>(values: &mut [F; NUM_COLUMNS]) {
pub(crate) fn generate_addition<F: PrimeField64>(values: &mut [F; NUM_COLUMNS]) {
let in_1 = values[COL_ADD_INPUT_1].to_canonical_u64();
let in_2 = values[COL_ADD_INPUT_2].to_canonical_u64();
let in_3 = values[COL_ADD_INPUT_3].to_canonical_u64();

View File

@@ -1,5 +1,5 @@
use plonky2::field::extension_field::Extendable;
use plonky2::field::field_types::Field;
use plonky2::field::field_types::{Field, PrimeField64};
use plonky2::field::packed_field::PackedField;
use plonky2::hash::hash_types::RichField;
use plonky2::iop::ext_target::ExtensionTarget;
@@ -9,7 +9,7 @@ use starky::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer};
use crate::registers::arithmetic::*;
use crate::registers::NUM_COLUMNS;
pub(crate) fn generate_division<F: RichField>(values: &mut [F; NUM_COLUMNS]) {
pub(crate) fn generate_division<F: PrimeField64>(values: &mut [F; NUM_COLUMNS]) {
// TODO
}

View File

@@ -1,5 +1,5 @@
use plonky2::field::extension_field::Extendable;
use plonky2::field::field_types::Field;
use plonky2::field::field_types::{Field, PrimeField64};
use plonky2::field::packed_field::PackedField;
use plonky2::hash::hash_types::RichField;
use plonky2::plonk::circuit_builder::CircuitBuilder;
@@ -24,7 +24,7 @@ mod division;
mod multiplication;
mod subtraction;
pub(crate) fn generate_arithmetic_unit<F: RichField>(values: &mut [F; NUM_COLUMNS]) {
pub(crate) fn generate_arithmetic_unit<F: PrimeField64>(values: &mut [F; NUM_COLUMNS]) {
if values[IS_ADD].is_one() {
generate_addition(values);
} else if values[IS_SUB].is_one() {

View File

@@ -1,5 +1,5 @@
use plonky2::field::extension_field::Extendable;
use plonky2::field::field_types::Field;
use plonky2::field::field_types::{Field, PrimeField64};
use plonky2::field::packed_field::PackedField;
use plonky2::hash::hash_types::RichField;
use plonky2::iop::ext_target::ExtensionTarget;
@@ -9,7 +9,7 @@ use starky::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer};
use crate::registers::arithmetic::*;
use crate::registers::NUM_COLUMNS;
pub(crate) fn generate_multiplication<F: RichField>(values: &mut [F; NUM_COLUMNS]) {
pub(crate) fn generate_multiplication<F: PrimeField64>(values: &mut [F; NUM_COLUMNS]) {
// TODO
}

View File

@@ -1,5 +1,5 @@
use plonky2::field::extension_field::Extendable;
use plonky2::field::field_types::Field;
use plonky2::field::field_types::{Field, PrimeField64};
use plonky2::field::packed_field::PackedField;
use plonky2::hash::hash_types::RichField;
use plonky2::iop::ext_target::ExtensionTarget;
@@ -9,7 +9,7 @@ use starky::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer};
use crate::registers::arithmetic::*;
use crate::registers::NUM_COLUMNS;
pub(crate) fn generate_subtraction<F: RichField>(values: &mut [F; NUM_COLUMNS]) {
pub(crate) fn generate_subtraction<F: PrimeField64>(values: &mut [F; NUM_COLUMNS]) {
// TODO
}

View File

@@ -1,4 +1,5 @@
use plonky2::field::extension_field::{Extendable, FieldExtension};
use plonky2::field::extension_field::Extendable;
use plonky2::field::field_types::{Field, PrimeField64};
use plonky2::field::packed_field::PackedField;
use plonky2::hash::hash_types::RichField;
use plonky2::plonk::circuit_builder::CircuitBuilder;
@@ -9,93 +10,84 @@ use starky::vars::StarkEvaluationVars;
use crate::public_input_layout::NUM_PUBLIC_INPUTS;
use crate::registers::core::*;
use crate::registers::NUM_COLUMNS;
use crate::system_zero::SystemZero;
impl<F: RichField + Extendable<D>, const D: usize> SystemZero<F, D> {
pub(crate) fn generate_first_row_core_registers(&self, first_values: &mut [F; NUM_COLUMNS]) {
first_values[COL_CLOCK] = F::ZERO;
first_values[COL_RANGE_16] = F::ZERO;
first_values[COL_INSTRUCTION_PTR] = F::ZERO;
first_values[COL_FRAME_PTR] = F::ZERO;
first_values[COL_STACK_PTR] = F::ZERO;
}
pub(crate) fn generate_next_row_core_registers(
&self,
local_values: &[F; NUM_COLUMNS],
next_values: &mut [F; NUM_COLUMNS],
) {
// We increment the clock by 1.
next_values[COL_CLOCK] = local_values[COL_CLOCK] + F::ONE;
// We increment the 16-bit table by 1, unless we've reached the max value of 2^16 - 1, in
// which case we repeat that value.
let prev_range_16 = local_values[COL_RANGE_16].to_canonical_u64();
let next_range_16 = (prev_range_16 + 1).min((1 << 16) - 1);
next_values[COL_RANGE_16] = F::from_canonical_u64(next_range_16);
// next_values[COL_INSTRUCTION_PTR] = todo!();
// next_values[COL_FRAME_PTR] = todo!();
// next_values[COL_STACK_PTR] = todo!();
}
#[inline]
pub(crate) fn eval_core_registers<FE, P, const D2: usize>(
&self,
vars: StarkEvaluationVars<FE, P, NUM_COLUMNS, NUM_PUBLIC_INPUTS>,
yield_constr: &mut ConstraintConsumer<P>,
) where
FE: FieldExtension<D2, BaseField = F>,
P: PackedField<Scalar = FE>,
{
// The clock must start with 0, and increment by 1.
let local_clock = vars.local_values[COL_CLOCK];
let next_clock = vars.next_values[COL_CLOCK];
let delta_clock = next_clock - local_clock;
yield_constr.constraint_first_row(local_clock);
yield_constr.constraint(delta_clock - FE::ONE);
// The 16-bit table must start with 0, end with 2^16 - 1, and increment by 0 or 1.
let local_range_16 = vars.local_values[COL_RANGE_16];
let next_range_16 = vars.next_values[COL_RANGE_16];
let delta_range_16 = next_range_16 - local_range_16;
yield_constr.constraint_first_row(local_range_16);
yield_constr.constraint_last_row(local_range_16 - FE::from_canonical_u64((1 << 16) - 1));
yield_constr.constraint(delta_range_16 * delta_range_16 - delta_range_16);
// TODO constraints for stack etc.
}
pub(crate) fn eval_core_registers_recursively(
&self,
builder: &mut CircuitBuilder<F, D>,
vars: StarkEvaluationTargets<D, NUM_COLUMNS, NUM_PUBLIC_INPUTS>,
yield_constr: &mut RecursiveConstraintConsumer<F, D>,
) {
let one_ext = builder.one_extension();
let max_u16 = builder.constant(F::from_canonical_u64((1 << 16) - 1));
let max_u16_ext = builder.convert_to_ext(max_u16);
// The clock must start with 0, and increment by 1.
let local_clock = vars.local_values[COL_CLOCK];
let next_clock = vars.next_values[COL_CLOCK];
let delta_clock = builder.sub_extension(next_clock, local_clock);
yield_constr.constraint_first_row(builder, local_clock);
let constraint = builder.sub_extension(delta_clock, one_ext);
yield_constr.constraint(builder, constraint);
// The 16-bit table must start with 0, end with 2^16 - 1, and increment by 0 or 1.
let local_range_16 = vars.local_values[COL_RANGE_16];
let next_range_16 = vars.next_values[COL_RANGE_16];
let delta_range_16 = builder.sub_extension(next_range_16, local_range_16);
yield_constr.constraint_first_row(builder, local_range_16);
let constraint = builder.sub_extension(local_range_16, max_u16_ext);
yield_constr.constraint_last_row(builder, constraint);
let constraint = builder.mul_add_extension(delta_range_16, delta_range_16, delta_range_16);
yield_constr.constraint(builder, constraint);
// TODO constraints for stack etc.
}
pub(crate) fn generate_first_row_core_registers<F: Field>(first_values: &mut [F; NUM_COLUMNS]) {
first_values[COL_CLOCK] = F::ZERO;
first_values[COL_RANGE_16] = F::ZERO;
first_values[COL_INSTRUCTION_PTR] = F::ZERO;
first_values[COL_FRAME_PTR] = F::ZERO;
first_values[COL_STACK_PTR] = F::ZERO;
}
pub(crate) fn generate_next_row_core_registers<F: PrimeField64>(
local_values: &[F; NUM_COLUMNS],
next_values: &mut [F; NUM_COLUMNS],
) {
// We increment the clock by 1.
next_values[COL_CLOCK] = local_values[COL_CLOCK] + F::ONE;
// We increment the 16-bit table by 1, unless we've reached the max value of 2^16 - 1, in
// which case we repeat that value.
let prev_range_16 = local_values[COL_RANGE_16].to_canonical_u64();
let next_range_16 = (prev_range_16 + 1).min((1 << 16) - 1);
next_values[COL_RANGE_16] = F::from_canonical_u64(next_range_16);
// next_values[COL_INSTRUCTION_PTR] = todo!();
// next_values[COL_FRAME_PTR] = todo!();
// next_values[COL_STACK_PTR] = todo!();
}
#[inline]
pub(crate) fn eval_core_registers<F: Field, P: PackedField<Scalar = F>>(
vars: StarkEvaluationVars<F, P, NUM_COLUMNS, NUM_PUBLIC_INPUTS>,
yield_constr: &mut ConstraintConsumer<P>,
) {
// The clock must start with 0, and increment by 1.
let local_clock = vars.local_values[COL_CLOCK];
let next_clock = vars.next_values[COL_CLOCK];
let delta_clock = next_clock - local_clock;
yield_constr.constraint_first_row(local_clock);
yield_constr.constraint(delta_clock - F::ONE);
// The 16-bit table must start with 0, end with 2^16 - 1, and increment by 0 or 1.
let local_range_16 = vars.local_values[COL_RANGE_16];
let next_range_16 = vars.next_values[COL_RANGE_16];
let delta_range_16 = next_range_16 - local_range_16;
yield_constr.constraint_first_row(local_range_16);
yield_constr.constraint_last_row(local_range_16 - F::from_canonical_u64((1 << 16) - 1));
yield_constr.constraint(delta_range_16 * delta_range_16 - delta_range_16);
// TODO constraints for stack etc.
}
pub(crate) fn eval_core_registers_recursively<F: RichField + Extendable<D>, const D: usize>(
builder: &mut CircuitBuilder<F, D>,
vars: StarkEvaluationTargets<D, NUM_COLUMNS, NUM_PUBLIC_INPUTS>,
yield_constr: &mut RecursiveConstraintConsumer<F, D>,
) {
let one_ext = builder.one_extension();
let max_u16 = builder.constant(F::from_canonical_u64((1 << 16) - 1));
let max_u16_ext = builder.convert_to_ext(max_u16);
// The clock must start with 0, and increment by 1.
let local_clock = vars.local_values[COL_CLOCK];
let next_clock = vars.next_values[COL_CLOCK];
let delta_clock = builder.sub_extension(next_clock, local_clock);
yield_constr.constraint_first_row(builder, local_clock);
let constraint = builder.sub_extension(delta_clock, one_ext);
yield_constr.constraint(builder, constraint);
// The 16-bit table must start with 0, end with 2^16 - 1, and increment by 0 or 1.
let local_range_16 = vars.local_values[COL_RANGE_16];
let next_range_16 = vars.next_values[COL_RANGE_16];
let delta_range_16 = builder.sub_extension(next_range_16, local_range_16);
yield_constr.constraint_first_row(builder, local_range_16);
let constraint = builder.sub_extension(local_range_16, max_u16_ext);
yield_constr.constraint_last_row(builder, constraint);
let constraint = builder.mul_add_extension(delta_range_16, delta_range_16, delta_range_16);
yield_constr.constraint(builder, constraint);
// TODO constraints for stack etc.
}
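A brief aside on the range-table constraint carried over unchanged in the diff above: `delta_range_16 * delta_range_16 - delta_range_16` vanishes exactly when the row-to-row increment is 0 or 1, which together with the first-row and last-row constraints pins the column to walk through every value in [0, 2^16 - 1]. Below is a tiny self-contained sanity check of that polynomial identity over plain integers (not the Goldilocks field; the same factoring argument holds in any field, which has no zero divisors).

```rust
fn main() {
    // delta^2 - delta factors as delta * (delta - 1), so it vanishes only at
    // delta = 0 or delta = 1.
    let roots: Vec<i64> = (-4..=4).filter(|&d| d * d - d == 0).collect();
    assert_eq!(roots, vec![0, 1]);
    println!("delta^2 - delta vanishes only at {:?}", roots);
}
```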

View File

@@ -2,7 +2,7 @@ use plonky2::field::extension_field::{Extendable, FieldExtension};
use plonky2::field::packed_field::PackedField;
use plonky2::hash::hash_types::RichField;
use plonky2::hash::hashing::SPONGE_WIDTH;
use plonky2::hash::poseidon::{HALF_N_FULL_ROUNDS, N_PARTIAL_ROUNDS};
use plonky2::hash::poseidon::{Poseidon, HALF_N_FULL_ROUNDS, N_PARTIAL_ROUNDS};
use plonky2::plonk::circuit_builder::CircuitBuilder;
use starky::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer};
use starky::vars::StarkEvaluationTargets;
@@ -11,15 +11,14 @@ use starky::vars::StarkEvaluationVars;
use crate::public_input_layout::NUM_PUBLIC_INPUTS;
use crate::registers::permutation::*;
use crate::registers::NUM_COLUMNS;
use crate::system_zero::SystemZero;
fn constant_layer<F, FE, P, const D2: usize>(
fn constant_layer<F, FE, P, const D: usize>(
mut state: [P; SPONGE_WIDTH],
round: usize,
) -> [P; SPONGE_WIDTH]
where
F: RichField,
FE: FieldExtension<D2, BaseField = F>,
F: Poseidon,
FE: FieldExtension<D, BaseField = F>,
P: PackedField<Scalar = FE>,
{
// One day I might actually vectorize this, but today is not that day.
@@ -36,10 +35,10 @@ where
state
}
fn mds_layer<F, FE, P, const D2: usize>(mut state: [P; SPONGE_WIDTH]) -> [P; SPONGE_WIDTH]
fn mds_layer<F, FE, P, const D: usize>(mut state: [P; SPONGE_WIDTH]) -> [P; SPONGE_WIDTH]
where
F: RichField,
FE: FieldExtension<D2, BaseField = F>,
F: Poseidon,
FE: FieldExtension<D, BaseField = F>,
P: PackedField<Scalar = FE>,
{
for i in 0..P::WIDTH {
@@ -55,206 +54,204 @@
state
}
impl<F: RichField + Extendable<D>, const D: usize> SystemZero<F, D> {
pub(crate) fn generate_permutation_unit(values: &mut [F; NUM_COLUMNS]) {
// Load inputs.
let mut state = [F::ZERO; SPONGE_WIDTH];
pub(crate) fn generate_permutation_unit<F: Poseidon>(values: &mut [F; NUM_COLUMNS]) {
// Load inputs.
let mut state = [F::ZERO; SPONGE_WIDTH];
for i in 0..SPONGE_WIDTH {
state[i] = values[col_input(i)];
}
for r in 0..HALF_N_FULL_ROUNDS {
F::constant_layer(&mut state, r);
for i in 0..SPONGE_WIDTH {
state[i] = values[col_input(i)];
let state_cubed = state[i].cube();
values[col_full_first_mid_sbox(r, i)] = state_cubed;
state[i] *= state_cubed.square(); // Form state ** 7.
}
for r in 0..HALF_N_FULL_ROUNDS {
F::constant_layer(&mut state, r);
state = F::mds_layer(&state);
for i in 0..SPONGE_WIDTH {
let state_cubed = state[i].cube();
values[col_full_first_mid_sbox(r, i)] = state_cubed;
state[i] *= state_cubed.square(); // Form state ** 7.
}
state = F::mds_layer(&state);
for i in 0..SPONGE_WIDTH {
values[col_full_first_after_mds(r, i)] = state[i];
}
}
for r in 0..N_PARTIAL_ROUNDS {
F::constant_layer(&mut state, HALF_N_FULL_ROUNDS + r);
let state0_cubed = state[0].cube();
values[col_partial_mid_sbox(r)] = state0_cubed;
state[0] *= state0_cubed.square(); // Form state ** 7.
values[col_partial_after_sbox(r)] = state[0];
state = F::mds_layer(&state);
}
for r in 0..HALF_N_FULL_ROUNDS {
F::constant_layer(&mut state, HALF_N_FULL_ROUNDS + N_PARTIAL_ROUNDS + r);
for i in 0..SPONGE_WIDTH {
let state_cubed = state[i].cube();
values[col_full_second_mid_sbox(r, i)] = state_cubed;
state[i] *= state_cubed.square(); // Form state ** 7.
}
state = F::mds_layer(&state);
for i in 0..SPONGE_WIDTH {
values[col_full_second_after_mds(r, i)] = state[i];
}
for i in 0..SPONGE_WIDTH {
values[col_full_first_after_mds(r, i)] = state[i];
}
}
#[inline]
pub(crate) fn eval_permutation_unit<FE, P, const D2: usize>(
vars: StarkEvaluationVars<FE, P, NUM_COLUMNS, NUM_PUBLIC_INPUTS>,
yield_constr: &mut ConstraintConsumer<P>,
) where
FE: FieldExtension<D2, BaseField = F>,
P: PackedField<Scalar = FE>,
{
let local_values = &vars.local_values;
for r in 0..N_PARTIAL_ROUNDS {
F::constant_layer(&mut state, HALF_N_FULL_ROUNDS + r);
let state0_cubed = state[0].cube();
values[col_partial_mid_sbox(r)] = state0_cubed;
state[0] *= state0_cubed.square(); // Form state ** 7.
values[col_partial_after_sbox(r)] = state[0];
state = F::mds_layer(&state);
}
for r in 0..HALF_N_FULL_ROUNDS {
F::constant_layer(&mut state, HALF_N_FULL_ROUNDS + N_PARTIAL_ROUNDS + r);
// Load inputs.
let mut state = [P::ZEROS; SPONGE_WIDTH];
for i in 0..SPONGE_WIDTH {
state[i] = local_values[col_input(i)];
let state_cubed = state[i].cube();
values[col_full_second_mid_sbox(r, i)] = state_cubed;
state[i] *= state_cubed.square(); // Form state ** 7.
}
for r in 0..HALF_N_FULL_ROUNDS {
state = constant_layer(state, r);
state = F::mds_layer(&state);
for i in 0..SPONGE_WIDTH {
let state_cubed = state[i] * state[i].square();
yield_constr
.constraint_wrapping(state_cubed - local_values[col_full_first_mid_sbox(r, i)]);
let state_cubed = local_values[col_full_first_mid_sbox(r, i)];
state[i] *= state_cubed.square(); // Form state ** 7.
}
for i in 0..SPONGE_WIDTH {
values[col_full_second_after_mds(r, i)] = state[i];
}
}
}
state = mds_layer(state);
#[inline]
pub(crate) fn eval_permutation_unit<F, FE, P, const D: usize>(
vars: StarkEvaluationVars<FE, P, NUM_COLUMNS, NUM_PUBLIC_INPUTS>,
yield_constr: &mut ConstraintConsumer<P>,
) where
F: Poseidon,
FE: FieldExtension<D, BaseField = F>,
P: PackedField<Scalar = FE>,
{
let local_values = &vars.local_values;
for i in 0..SPONGE_WIDTH {
yield_constr
.constraint_wrapping(state[i] - local_values[col_full_first_after_mds(r, i)]);
state[i] = local_values[col_full_first_after_mds(r, i)];
}
// Load inputs.
let mut state = [P::ZEROS; SPONGE_WIDTH];
for i in 0..SPONGE_WIDTH {
state[i] = local_values[col_input(i)];
}
for r in 0..HALF_N_FULL_ROUNDS {
state = constant_layer(state, r);
for i in 0..SPONGE_WIDTH {
let state_cubed = state[i] * state[i].square();
yield_constr
.constraint_wrapping(state_cubed - local_values[col_full_first_mid_sbox(r, i)]);
let state_cubed = local_values[col_full_first_mid_sbox(r, i)];
state[i] *= state_cubed.square(); // Form state ** 7.
}
for r in 0..N_PARTIAL_ROUNDS {
state = constant_layer(state, HALF_N_FULL_ROUNDS + r);
state = mds_layer(state);
let state0_cubed = state[0] * state[0].square();
yield_constr.constraint_wrapping(state0_cubed - local_values[col_partial_mid_sbox(r)]);
let state0_cubed = local_values[col_partial_mid_sbox(r)];
state[0] *= state0_cubed.square(); // Form state ** 7.
yield_constr.constraint_wrapping(state[0] - local_values[col_partial_after_sbox(r)]);
state[0] = local_values[col_partial_after_sbox(r)];
state = mds_layer(state);
}
for r in 0..HALF_N_FULL_ROUNDS {
state = constant_layer(state, HALF_N_FULL_ROUNDS + N_PARTIAL_ROUNDS + r);
for i in 0..SPONGE_WIDTH {
let state_cubed = state[i] * state[i].square();
yield_constr.constraint_wrapping(
state_cubed - local_values[col_full_second_mid_sbox(r, i)],
);
let state_cubed = local_values[col_full_second_mid_sbox(r, i)];
state[i] *= state_cubed.square(); // Form state ** 7.
}
state = mds_layer(state);
for i in 0..SPONGE_WIDTH {
yield_constr
.constraint_wrapping(state[i] - local_values[col_full_second_after_mds(r, i)]);
state[i] = local_values[col_full_second_after_mds(r, i)];
}
for i in 0..SPONGE_WIDTH {
yield_constr
.constraint_wrapping(state[i] - local_values[col_full_first_after_mds(r, i)]);
state[i] = local_values[col_full_first_after_mds(r, i)];
}
}
pub(crate) fn eval_permutation_unit_recursively(
builder: &mut CircuitBuilder<F, D>,
vars: StarkEvaluationTargets<D, NUM_COLUMNS, NUM_PUBLIC_INPUTS>,
yield_constr: &mut RecursiveConstraintConsumer<F, D>,
) {
let zero = builder.zero_extension();
let local_values = &vars.local_values;
for r in 0..N_PARTIAL_ROUNDS {
state = constant_layer(state, HALF_N_FULL_ROUNDS + r);
let state0_cubed = state[0] * state[0].square();
yield_constr.constraint_wrapping(state0_cubed - local_values[col_partial_mid_sbox(r)]);
let state0_cubed = local_values[col_partial_mid_sbox(r)];
state[0] *= state0_cubed.square(); // Form state ** 7.
yield_constr.constraint_wrapping(state[0] - local_values[col_partial_after_sbox(r)]);
state[0] = local_values[col_partial_after_sbox(r)];
state = mds_layer(state);
}
for r in 0..HALF_N_FULL_ROUNDS {
state = constant_layer(state, HALF_N_FULL_ROUNDS + N_PARTIAL_ROUNDS + r);
// Load inputs.
let mut state = [zero; SPONGE_WIDTH];
for i in 0..SPONGE_WIDTH {
state[i] = local_values[col_input(i)];
let state_cubed = state[i] * state[i].square();
yield_constr
.constraint_wrapping(state_cubed - local_values[col_full_second_mid_sbox(r, i)]);
let state_cubed = local_values[col_full_second_mid_sbox(r, i)];
state[i] *= state_cubed.square(); // Form state ** 7.
}
for r in 0..HALF_N_FULL_ROUNDS {
F::constant_layer_recursive(builder, &mut state, r);
state = mds_layer(state);
for i in 0..SPONGE_WIDTH {
let state_cubed = builder.cube_extension(state[i]);
let diff =
builder.sub_extension(state_cubed, local_values[col_full_first_mid_sbox(r, i)]);
yield_constr.constraint_wrapping(builder, diff);
let state_cubed = local_values[col_full_first_mid_sbox(r, i)];
state[i] = builder.mul_many_extension(&[state[i], state_cubed, state_cubed]);
// Form state ** 7.
}
state = F::mds_layer_recursive(builder, &state);
for i in 0..SPONGE_WIDTH {
let diff =
builder.sub_extension(state[i], local_values[col_full_first_after_mds(r, i)]);
yield_constr.constraint_wrapping(builder, diff);
state[i] = local_values[col_full_first_after_mds(r, i)];
}
for i in 0..SPONGE_WIDTH {
yield_constr
.constraint_wrapping(state[i] - local_values[col_full_second_after_mds(r, i)]);
state[i] = local_values[col_full_second_after_mds(r, i)];
}
}
}
for r in 0..N_PARTIAL_ROUNDS {
F::constant_layer_recursive(builder, &mut state, HALF_N_FULL_ROUNDS + r);
pub(crate) fn eval_permutation_unit_recursively<F: RichField + Extendable<D>, const D: usize>(
builder: &mut CircuitBuilder<F, D>,
vars: StarkEvaluationTargets<D, NUM_COLUMNS, NUM_PUBLIC_INPUTS>,
yield_constr: &mut RecursiveConstraintConsumer<F, D>,
) {
let zero = builder.zero_extension();
let local_values = &vars.local_values;
let state0_cubed = builder.cube_extension(state[0]);
let diff = builder.sub_extension(state0_cubed, local_values[col_partial_mid_sbox(r)]);
// Load inputs.
let mut state = [zero; SPONGE_WIDTH];
for i in 0..SPONGE_WIDTH {
state[i] = local_values[col_input(i)];
}
for r in 0..HALF_N_FULL_ROUNDS {
F::constant_layer_recursive(builder, &mut state, r);
for i in 0..SPONGE_WIDTH {
let state_cubed = builder.cube_extension(state[i]);
let diff =
builder.sub_extension(state_cubed, local_values[col_full_first_mid_sbox(r, i)]);
yield_constr.constraint_wrapping(builder, diff);
let state0_cubed = local_values[col_partial_mid_sbox(r)];
state[0] = builder.mul_many_extension(&[state[0], state0_cubed, state0_cubed]); // Form state ** 7.
let diff = builder.sub_extension(state[0], local_values[col_partial_after_sbox(r)]);
yield_constr.constraint_wrapping(builder, diff);
state[0] = local_values[col_partial_after_sbox(r)];
state = F::mds_layer_recursive(builder, &state);
let state_cubed = local_values[col_full_first_mid_sbox(r, i)];
state[i] = builder.mul_many_extension(&[state[i], state_cubed, state_cubed]);
// Form state ** 7.
}
for r in 0..HALF_N_FULL_ROUNDS {
F::constant_layer_recursive(
builder,
&mut state,
HALF_N_FULL_ROUNDS + N_PARTIAL_ROUNDS + r,
);
state = F::mds_layer_recursive(builder, &state);
for i in 0..SPONGE_WIDTH {
let state_cubed = builder.cube_extension(state[i]);
let diff = builder
.sub_extension(state_cubed, local_values[col_full_second_mid_sbox(r, i)]);
yield_constr.constraint_wrapping(builder, diff);
let state_cubed = local_values[col_full_second_mid_sbox(r, i)];
state[i] = builder.mul_many_extension(&[state[i], state_cubed, state_cubed]);
// Form state ** 7.
}
for i in 0..SPONGE_WIDTH {
let diff =
builder.sub_extension(state[i], local_values[col_full_first_after_mds(r, i)]);
yield_constr.constraint_wrapping(builder, diff);
state[i] = local_values[col_full_first_after_mds(r, i)];
}
}
state = F::mds_layer_recursive(builder, &state);
for r in 0..N_PARTIAL_ROUNDS {
F::constant_layer_recursive(builder, &mut state, HALF_N_FULL_ROUNDS + r);
for i in 0..SPONGE_WIDTH {
let diff =
builder.sub_extension(state[i], local_values[col_full_second_after_mds(r, i)]);
yield_constr.constraint_wrapping(builder, diff);
state[i] = local_values[col_full_second_after_mds(r, i)];
}
let state0_cubed = builder.cube_extension(state[0]);
let diff = builder.sub_extension(state0_cubed, local_values[col_partial_mid_sbox(r)]);
yield_constr.constraint_wrapping(builder, diff);
let state0_cubed = local_values[col_partial_mid_sbox(r)];
state[0] = builder.mul_many_extension(&[state[0], state0_cubed, state0_cubed]); // Form state ** 7.
let diff = builder.sub_extension(state[0], local_values[col_partial_after_sbox(r)]);
yield_constr.constraint_wrapping(builder, diff);
state[0] = local_values[col_partial_after_sbox(r)];
state = F::mds_layer_recursive(builder, &state);
}
for r in 0..HALF_N_FULL_ROUNDS {
F::constant_layer_recursive(
builder,
&mut state,
HALF_N_FULL_ROUNDS + N_PARTIAL_ROUNDS + r,
);
for i in 0..SPONGE_WIDTH {
let state_cubed = builder.cube_extension(state[i]);
let diff =
builder.sub_extension(state_cubed, local_values[col_full_second_mid_sbox(r, i)]);
yield_constr.constraint_wrapping(builder, diff);
let state_cubed = local_values[col_full_second_mid_sbox(r, i)];
state[i] = builder.mul_many_extension(&[state[i], state_cubed, state_cubed]);
// Form state ** 7.
}
state = F::mds_layer_recursive(builder, &state);
for i in 0..SPONGE_WIDTH {
let diff =
builder.sub_extension(state[i], local_values[col_full_second_after_mds(r, i)]);
yield_constr.constraint_wrapping(builder, diff);
state[i] = local_values[col_full_second_after_mds(r, i)];
}
}
}
@@ -269,11 +266,10 @@ mod tests {
use starky::constraint_consumer::ConstraintConsumer;
use starky::vars::StarkEvaluationVars;
use crate::permutation_unit::SPONGE_WIDTH;
use crate::permutation_unit::{eval_permutation_unit, generate_permutation_unit, SPONGE_WIDTH};
use crate::public_input_layout::NUM_PUBLIC_INPUTS;
use crate::registers::permutation::{col_input, col_output};
use crate::registers::NUM_COLUMNS;
use crate::system_zero::SystemZero;
#[test]
fn generate_eval_consistency() {
@@ -281,7 +277,7 @@ mod tests {
type F = GoldilocksField;
let mut values = [F::default(); NUM_COLUMNS];
SystemZero::<F, D>::generate_permutation_unit(&mut values);
generate_permutation_unit(&mut values);
let vars = StarkEvaluationVars {
local_values: &values,
@@ -295,7 +291,7 @@
GoldilocksField::ONE,
GoldilocksField::ONE,
);
SystemZero::<F, D>::eval_permutation_unit(vars, &mut constrant_consumer);
eval_permutation_unit(vars, &mut constrant_consumer);
for &acc in &constrant_consumer.constraint_accs {
assert_eq!(acc, GoldilocksField::ZERO);
}
@@ -318,7 +314,7 @@ mod tests {
for i in 0..SPONGE_WIDTH {
values[col_input(i)] = state[i];
}
SystemZero::<F, D>::generate_permutation_unit(&mut values);
generate_permutation_unit(&mut values);
let mut result = [F::default(); SPONGE_WIDTH];
for i in 0..SPONGE_WIDTH {
result[i] = values[col_output(i)];
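One detail worth spelling out from the permutation unit above: the Poseidon S-box x ↦ x^7 is evaluated in two steps so the witness can record the intermediate cube. The generator writes x^3 into a trace column and then multiplies the running value by the square of that column, since x · (x^3)^2 = x^7; the evaluation side recomputes the cube, checks it against the committed column, and reuses the committed value, which keeps each individual constraint at low degree rather than degree 7. Below is a minimal standalone check of the identity over plain integers (not the Goldilocks field).

```rust
fn main() {
    // x * (x^3)^2 == x^7: the split used by the mid-sbox trace columns above.
    for x in 1u128..=20 {
        let cube = x * x * x; // the value the generator records in the trace
        assert_eq!(x * cube * cube, x.pow(7));
    }
    println!("x * (x^3)^2 == x^7 holds for the sampled inputs");
}
```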

View File

@@ -12,7 +12,14 @@ use starky::vars::StarkEvaluationVars;
use crate::arithmetic::{
eval_arithmetic_unit, eval_arithmetic_unit_recursively, generate_arithmetic_unit,
};
use crate::core_registers::{
eval_core_registers, eval_core_registers_recursively, generate_first_row_core_registers,
generate_next_row_core_registers,
};
use crate::memory::TransactionMemory;
use crate::permutation_unit::{
eval_permutation_unit, eval_permutation_unit_recursively, generate_permutation_unit,
};
use crate::public_input_layout::NUM_PUBLIC_INPUTS;
use crate::registers::NUM_COLUMNS;
@@ -29,16 +36,17 @@ impl<F: RichField + Extendable<D>, const D: usize> SystemZero<F, D> {
let memory = TransactionMemory::default();
let mut row = [F::ZERO; NUM_COLUMNS];
self.generate_first_row_core_registers(&mut row);
Self::generate_permutation_unit(&mut row);
generate_first_row_core_registers(&mut row);
generate_arithmetic_unit(&mut row);
generate_permutation_unit(&mut row);
let mut trace = Vec::with_capacity(MIN_TRACE_ROWS);
loop {
let mut next_row = [F::ZERO; NUM_COLUMNS];
self.generate_next_row_core_registers(&row, &mut next_row);
generate_next_row_core_registers(&row, &mut next_row);
generate_arithmetic_unit(&mut next_row);
Self::generate_permutation_unit(&mut next_row);
generate_permutation_unit(&mut next_row);
trace.push(row);
row = next_row;
@@ -74,9 +82,9 @@ impl<F: RichField + Extendable<D>, const D: usize> Stark<F, D> for SystemZero<F, D>
FE: FieldExtension<D2, BaseField = F>,
P: PackedField<Scalar = FE>,
{
self.eval_core_registers(vars, yield_constr);
eval_core_registers(vars, yield_constr);
eval_arithmetic_unit(vars, yield_constr);
Self::eval_permutation_unit(vars, yield_constr);
eval_permutation_unit::<F, FE, P, D2>(vars, yield_constr);
// TODO: Other units
}
@@ -86,9 +94,9 @@ impl<F: RichField + Extendable<D>, const D: usize> Stark<F, D> for SystemZero<F, D>
vars: StarkEvaluationTargets<D, NUM_COLUMNS, NUM_PUBLIC_INPUTS>,
yield_constr: &mut RecursiveConstraintConsumer<F, D>,
) {
self.eval_core_registers_recursively(builder, vars, yield_constr);
eval_core_registers_recursively(builder, vars, yield_constr);
eval_arithmetic_unit_recursively(builder, vars, yield_constr);
Self::eval_permutation_unit_recursively(builder, vars, yield_constr);
eval_permutation_unit_recursively(builder, vars, yield_constr);
// TODO: Other units
}