From 8d699edf21a1e7276aa465df0a88595b6df1656b Mon Sep 17 00:00:00 2001 From: Daniel Lubarov Date: Mon, 14 Feb 2022 13:47:33 -0800 Subject: [PATCH] Move some methods outside `impl System` (#484) I didn't really have a good reason for putting there; seems more idiomatic to make them global since they don't need `self`/`Self`. --- system_zero/src/arithmetic/addition.rs | 4 +- system_zero/src/arithmetic/division.rs | 4 +- system_zero/src/arithmetic/mod.rs | 4 +- system_zero/src/arithmetic/multiplication.rs | 4 +- system_zero/src/arithmetic/subtraction.rs | 4 +- system_zero/src/core_registers.rs | 170 +++++---- system_zero/src/permutation_unit.rs | 354 +++++++++---------- system_zero/src/system_zero.rs | 24 +- 8 files changed, 282 insertions(+), 286 deletions(-) diff --git a/system_zero/src/arithmetic/addition.rs b/system_zero/src/arithmetic/addition.rs index 653d533b..7aa0d81a 100644 --- a/system_zero/src/arithmetic/addition.rs +++ b/system_zero/src/arithmetic/addition.rs @@ -1,5 +1,5 @@ use plonky2::field::extension_field::Extendable; -use plonky2::field::field_types::Field; +use plonky2::field::field_types::{Field, PrimeField64}; use plonky2::field::packed_field::PackedField; use plonky2::hash::hash_types::RichField; use plonky2::iop::ext_target::ExtensionTarget; @@ -10,7 +10,7 @@ use starky::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsume use crate::registers::arithmetic::*; use crate::registers::NUM_COLUMNS; -pub(crate) fn generate_addition(values: &mut [F; NUM_COLUMNS]) { +pub(crate) fn generate_addition(values: &mut [F; NUM_COLUMNS]) { let in_1 = values[COL_ADD_INPUT_1].to_canonical_u64(); let in_2 = values[COL_ADD_INPUT_2].to_canonical_u64(); let in_3 = values[COL_ADD_INPUT_3].to_canonical_u64(); diff --git a/system_zero/src/arithmetic/division.rs b/system_zero/src/arithmetic/division.rs index 2f15b233..e91288b9 100644 --- a/system_zero/src/arithmetic/division.rs +++ b/system_zero/src/arithmetic/division.rs @@ -1,5 +1,5 @@ use plonky2::field::extension_field::Extendable; -use plonky2::field::field_types::Field; +use plonky2::field::field_types::{Field, PrimeField64}; use plonky2::field::packed_field::PackedField; use plonky2::hash::hash_types::RichField; use plonky2::iop::ext_target::ExtensionTarget; @@ -9,7 +9,7 @@ use starky::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsume use crate::registers::arithmetic::*; use crate::registers::NUM_COLUMNS; -pub(crate) fn generate_division(values: &mut [F; NUM_COLUMNS]) { +pub(crate) fn generate_division(values: &mut [F; NUM_COLUMNS]) { // TODO } diff --git a/system_zero/src/arithmetic/mod.rs b/system_zero/src/arithmetic/mod.rs index 45a9f7d9..a2b3a4f8 100644 --- a/system_zero/src/arithmetic/mod.rs +++ b/system_zero/src/arithmetic/mod.rs @@ -1,5 +1,5 @@ use plonky2::field::extension_field::Extendable; -use plonky2::field::field_types::Field; +use plonky2::field::field_types::{Field, PrimeField64}; use plonky2::field::packed_field::PackedField; use plonky2::hash::hash_types::RichField; use plonky2::plonk::circuit_builder::CircuitBuilder; @@ -24,7 +24,7 @@ mod division; mod multiplication; mod subtraction; -pub(crate) fn generate_arithmetic_unit(values: &mut [F; NUM_COLUMNS]) { +pub(crate) fn generate_arithmetic_unit(values: &mut [F; NUM_COLUMNS]) { if values[IS_ADD].is_one() { generate_addition(values); } else if values[IS_SUB].is_one() { diff --git a/system_zero/src/arithmetic/multiplication.rs b/system_zero/src/arithmetic/multiplication.rs index 2eefad38..70c181d8 100644 --- a/system_zero/src/arithmetic/multiplication.rs +++ b/system_zero/src/arithmetic/multiplication.rs @@ -1,5 +1,5 @@ use plonky2::field::extension_field::Extendable; -use plonky2::field::field_types::Field; +use plonky2::field::field_types::{Field, PrimeField64}; use plonky2::field::packed_field::PackedField; use plonky2::hash::hash_types::RichField; use plonky2::iop::ext_target::ExtensionTarget; @@ -9,7 +9,7 @@ use starky::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsume use crate::registers::arithmetic::*; use crate::registers::NUM_COLUMNS; -pub(crate) fn generate_multiplication(values: &mut [F; NUM_COLUMNS]) { +pub(crate) fn generate_multiplication(values: &mut [F; NUM_COLUMNS]) { // TODO } diff --git a/system_zero/src/arithmetic/subtraction.rs b/system_zero/src/arithmetic/subtraction.rs index 3613dee6..267bac72 100644 --- a/system_zero/src/arithmetic/subtraction.rs +++ b/system_zero/src/arithmetic/subtraction.rs @@ -1,5 +1,5 @@ use plonky2::field::extension_field::Extendable; -use plonky2::field::field_types::Field; +use plonky2::field::field_types::{Field, PrimeField64}; use plonky2::field::packed_field::PackedField; use plonky2::hash::hash_types::RichField; use plonky2::iop::ext_target::ExtensionTarget; @@ -9,7 +9,7 @@ use starky::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsume use crate::registers::arithmetic::*; use crate::registers::NUM_COLUMNS; -pub(crate) fn generate_subtraction(values: &mut [F; NUM_COLUMNS]) { +pub(crate) fn generate_subtraction(values: &mut [F; NUM_COLUMNS]) { // TODO } diff --git a/system_zero/src/core_registers.rs b/system_zero/src/core_registers.rs index 03e7fa04..c8c6533b 100644 --- a/system_zero/src/core_registers.rs +++ b/system_zero/src/core_registers.rs @@ -1,4 +1,5 @@ -use plonky2::field::extension_field::{Extendable, FieldExtension}; +use plonky2::field::extension_field::Extendable; +use plonky2::field::field_types::{Field, PrimeField64}; use plonky2::field::packed_field::PackedField; use plonky2::hash::hash_types::RichField; use plonky2::plonk::circuit_builder::CircuitBuilder; @@ -9,93 +10,84 @@ use starky::vars::StarkEvaluationVars; use crate::public_input_layout::NUM_PUBLIC_INPUTS; use crate::registers::core::*; use crate::registers::NUM_COLUMNS; -use crate::system_zero::SystemZero; -impl, const D: usize> SystemZero { - pub(crate) fn generate_first_row_core_registers(&self, first_values: &mut [F; NUM_COLUMNS]) { - first_values[COL_CLOCK] = F::ZERO; - first_values[COL_RANGE_16] = F::ZERO; - first_values[COL_INSTRUCTION_PTR] = F::ZERO; - first_values[COL_FRAME_PTR] = F::ZERO; - first_values[COL_STACK_PTR] = F::ZERO; - } - - pub(crate) fn generate_next_row_core_registers( - &self, - local_values: &[F; NUM_COLUMNS], - next_values: &mut [F; NUM_COLUMNS], - ) { - // We increment the clock by 1. - next_values[COL_CLOCK] = local_values[COL_CLOCK] + F::ONE; - - // We increment the 16-bit table by 1, unless we've reached the max value of 2^16 - 1, in - // which case we repeat that value. - let prev_range_16 = local_values[COL_RANGE_16].to_canonical_u64(); - let next_range_16 = (prev_range_16 + 1).min((1 << 16) - 1); - next_values[COL_RANGE_16] = F::from_canonical_u64(next_range_16); - - // next_values[COL_INSTRUCTION_PTR] = todo!(); - - // next_values[COL_FRAME_PTR] = todo!(); - - // next_values[COL_STACK_PTR] = todo!(); - } - - #[inline] - pub(crate) fn eval_core_registers( - &self, - vars: StarkEvaluationVars, - yield_constr: &mut ConstraintConsumer

, - ) where - FE: FieldExtension, - P: PackedField, - { - // The clock must start with 0, and increment by 1. - let local_clock = vars.local_values[COL_CLOCK]; - let next_clock = vars.next_values[COL_CLOCK]; - let delta_clock = next_clock - local_clock; - yield_constr.constraint_first_row(local_clock); - yield_constr.constraint(delta_clock - FE::ONE); - - // The 16-bit table must start with 0, end with 2^16 - 1, and increment by 0 or 1. - let local_range_16 = vars.local_values[COL_RANGE_16]; - let next_range_16 = vars.next_values[COL_RANGE_16]; - let delta_range_16 = next_range_16 - local_range_16; - yield_constr.constraint_first_row(local_range_16); - yield_constr.constraint_last_row(local_range_16 - FE::from_canonical_u64((1 << 16) - 1)); - yield_constr.constraint(delta_range_16 * delta_range_16 - delta_range_16); - - // TODO constraints for stack etc. - } - - pub(crate) fn eval_core_registers_recursively( - &self, - builder: &mut CircuitBuilder, - vars: StarkEvaluationTargets, - yield_constr: &mut RecursiveConstraintConsumer, - ) { - let one_ext = builder.one_extension(); - let max_u16 = builder.constant(F::from_canonical_u64((1 << 16) - 1)); - let max_u16_ext = builder.convert_to_ext(max_u16); - - // The clock must start with 0, and increment by 1. - let local_clock = vars.local_values[COL_CLOCK]; - let next_clock = vars.next_values[COL_CLOCK]; - let delta_clock = builder.sub_extension(next_clock, local_clock); - yield_constr.constraint_first_row(builder, local_clock); - let constraint = builder.sub_extension(delta_clock, one_ext); - yield_constr.constraint(builder, constraint); - - // The 16-bit table must start with 0, end with 2^16 - 1, and increment by 0 or 1. - let local_range_16 = vars.local_values[COL_RANGE_16]; - let next_range_16 = vars.next_values[COL_RANGE_16]; - let delta_range_16 = builder.sub_extension(next_range_16, local_range_16); - yield_constr.constraint_first_row(builder, local_range_16); - let constraint = builder.sub_extension(local_range_16, max_u16_ext); - yield_constr.constraint_last_row(builder, constraint); - let constraint = builder.mul_add_extension(delta_range_16, delta_range_16, delta_range_16); - yield_constr.constraint(builder, constraint); - - // TODO constraints for stack etc. - } +pub(crate) fn generate_first_row_core_registers(first_values: &mut [F; NUM_COLUMNS]) { + first_values[COL_CLOCK] = F::ZERO; + first_values[COL_RANGE_16] = F::ZERO; + first_values[COL_INSTRUCTION_PTR] = F::ZERO; + first_values[COL_FRAME_PTR] = F::ZERO; + first_values[COL_STACK_PTR] = F::ZERO; +} + +pub(crate) fn generate_next_row_core_registers( + local_values: &[F; NUM_COLUMNS], + next_values: &mut [F; NUM_COLUMNS], +) { + // We increment the clock by 1. + next_values[COL_CLOCK] = local_values[COL_CLOCK] + F::ONE; + + // We increment the 16-bit table by 1, unless we've reached the max value of 2^16 - 1, in + // which case we repeat that value. + let prev_range_16 = local_values[COL_RANGE_16].to_canonical_u64(); + let next_range_16 = (prev_range_16 + 1).min((1 << 16) - 1); + next_values[COL_RANGE_16] = F::from_canonical_u64(next_range_16); + + // next_values[COL_INSTRUCTION_PTR] = todo!(); + + // next_values[COL_FRAME_PTR] = todo!(); + + // next_values[COL_STACK_PTR] = todo!(); +} + +#[inline] +pub(crate) fn eval_core_registers>( + vars: StarkEvaluationVars, + yield_constr: &mut ConstraintConsumer

, +) { + // The clock must start with 0, and increment by 1. + let local_clock = vars.local_values[COL_CLOCK]; + let next_clock = vars.next_values[COL_CLOCK]; + let delta_clock = next_clock - local_clock; + yield_constr.constraint_first_row(local_clock); + yield_constr.constraint(delta_clock - F::ONE); + + // The 16-bit table must start with 0, end with 2^16 - 1, and increment by 0 or 1. + let local_range_16 = vars.local_values[COL_RANGE_16]; + let next_range_16 = vars.next_values[COL_RANGE_16]; + let delta_range_16 = next_range_16 - local_range_16; + yield_constr.constraint_first_row(local_range_16); + yield_constr.constraint_last_row(local_range_16 - F::from_canonical_u64((1 << 16) - 1)); + yield_constr.constraint(delta_range_16 * delta_range_16 - delta_range_16); + + // TODO constraints for stack etc. +} + +pub(crate) fn eval_core_registers_recursively, const D: usize>( + builder: &mut CircuitBuilder, + vars: StarkEvaluationTargets, + yield_constr: &mut RecursiveConstraintConsumer, +) { + let one_ext = builder.one_extension(); + let max_u16 = builder.constant(F::from_canonical_u64((1 << 16) - 1)); + let max_u16_ext = builder.convert_to_ext(max_u16); + + // The clock must start with 0, and increment by 1. + let local_clock = vars.local_values[COL_CLOCK]; + let next_clock = vars.next_values[COL_CLOCK]; + let delta_clock = builder.sub_extension(next_clock, local_clock); + yield_constr.constraint_first_row(builder, local_clock); + let constraint = builder.sub_extension(delta_clock, one_ext); + yield_constr.constraint(builder, constraint); + + // The 16-bit table must start with 0, end with 2^16 - 1, and increment by 0 or 1. + let local_range_16 = vars.local_values[COL_RANGE_16]; + let next_range_16 = vars.next_values[COL_RANGE_16]; + let delta_range_16 = builder.sub_extension(next_range_16, local_range_16); + yield_constr.constraint_first_row(builder, local_range_16); + let constraint = builder.sub_extension(local_range_16, max_u16_ext); + yield_constr.constraint_last_row(builder, constraint); + let constraint = builder.mul_add_extension(delta_range_16, delta_range_16, delta_range_16); + yield_constr.constraint(builder, constraint); + + // TODO constraints for stack etc. } diff --git a/system_zero/src/permutation_unit.rs b/system_zero/src/permutation_unit.rs index 2681f2d9..366cff65 100644 --- a/system_zero/src/permutation_unit.rs +++ b/system_zero/src/permutation_unit.rs @@ -2,7 +2,7 @@ use plonky2::field::extension_field::{Extendable, FieldExtension}; use plonky2::field::packed_field::PackedField; use plonky2::hash::hash_types::RichField; use plonky2::hash::hashing::SPONGE_WIDTH; -use plonky2::hash::poseidon::{HALF_N_FULL_ROUNDS, N_PARTIAL_ROUNDS}; +use plonky2::hash::poseidon::{Poseidon, HALF_N_FULL_ROUNDS, N_PARTIAL_ROUNDS}; use plonky2::plonk::circuit_builder::CircuitBuilder; use starky::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; use starky::vars::StarkEvaluationTargets; @@ -11,15 +11,14 @@ use starky::vars::StarkEvaluationVars; use crate::public_input_layout::NUM_PUBLIC_INPUTS; use crate::registers::permutation::*; use crate::registers::NUM_COLUMNS; -use crate::system_zero::SystemZero; -fn constant_layer( +fn constant_layer( mut state: [P; SPONGE_WIDTH], round: usize, ) -> [P; SPONGE_WIDTH] where - F: RichField, - FE: FieldExtension, + F: Poseidon, + FE: FieldExtension, P: PackedField, { // One day I might actually vectorize this, but today is not that day. @@ -36,10 +35,10 @@ where state } -fn mds_layer(mut state: [P; SPONGE_WIDTH]) -> [P; SPONGE_WIDTH] +fn mds_layer(mut state: [P; SPONGE_WIDTH]) -> [P; SPONGE_WIDTH] where - F: RichField, - FE: FieldExtension, + F: Poseidon, + FE: FieldExtension, P: PackedField, { for i in 0..P::WIDTH { @@ -55,206 +54,204 @@ where state } -impl, const D: usize> SystemZero { - pub(crate) fn generate_permutation_unit(values: &mut [F; NUM_COLUMNS]) { - // Load inputs. - let mut state = [F::ZERO; SPONGE_WIDTH]; +pub(crate) fn generate_permutation_unit(values: &mut [F; NUM_COLUMNS]) { + // Load inputs. + let mut state = [F::ZERO; SPONGE_WIDTH]; + for i in 0..SPONGE_WIDTH { + state[i] = values[col_input(i)]; + } + + for r in 0..HALF_N_FULL_ROUNDS { + F::constant_layer(&mut state, r); + for i in 0..SPONGE_WIDTH { - state[i] = values[col_input(i)]; + let state_cubed = state[i].cube(); + values[col_full_first_mid_sbox(r, i)] = state_cubed; + state[i] *= state_cubed.square(); // Form state ** 7. } - for r in 0..HALF_N_FULL_ROUNDS { - F::constant_layer(&mut state, r); + state = F::mds_layer(&state); - for i in 0..SPONGE_WIDTH { - let state_cubed = state[i].cube(); - values[col_full_first_mid_sbox(r, i)] = state_cubed; - state[i] *= state_cubed.square(); // Form state ** 7. - } - - state = F::mds_layer(&state); - - for i in 0..SPONGE_WIDTH { - values[col_full_first_after_mds(r, i)] = state[i]; - } - } - - for r in 0..N_PARTIAL_ROUNDS { - F::constant_layer(&mut state, HALF_N_FULL_ROUNDS + r); - - let state0_cubed = state[0].cube(); - values[col_partial_mid_sbox(r)] = state0_cubed; - state[0] *= state0_cubed.square(); // Form state ** 7. - values[col_partial_after_sbox(r)] = state[0]; - - state = F::mds_layer(&state); - } - - for r in 0..HALF_N_FULL_ROUNDS { - F::constant_layer(&mut state, HALF_N_FULL_ROUNDS + N_PARTIAL_ROUNDS + r); - - for i in 0..SPONGE_WIDTH { - let state_cubed = state[i].cube(); - values[col_full_second_mid_sbox(r, i)] = state_cubed; - state[i] *= state_cubed.square(); // Form state ** 7. - } - - state = F::mds_layer(&state); - - for i in 0..SPONGE_WIDTH { - values[col_full_second_after_mds(r, i)] = state[i]; - } + for i in 0..SPONGE_WIDTH { + values[col_full_first_after_mds(r, i)] = state[i]; } } - #[inline] - pub(crate) fn eval_permutation_unit( - vars: StarkEvaluationVars, - yield_constr: &mut ConstraintConsumer

, - ) where - FE: FieldExtension, - P: PackedField, - { - let local_values = &vars.local_values; + for r in 0..N_PARTIAL_ROUNDS { + F::constant_layer(&mut state, HALF_N_FULL_ROUNDS + r); + + let state0_cubed = state[0].cube(); + values[col_partial_mid_sbox(r)] = state0_cubed; + state[0] *= state0_cubed.square(); // Form state ** 7. + values[col_partial_after_sbox(r)] = state[0]; + + state = F::mds_layer(&state); + } + + for r in 0..HALF_N_FULL_ROUNDS { + F::constant_layer(&mut state, HALF_N_FULL_ROUNDS + N_PARTIAL_ROUNDS + r); - // Load inputs. - let mut state = [P::ZEROS; SPONGE_WIDTH]; for i in 0..SPONGE_WIDTH { - state[i] = local_values[col_input(i)]; + let state_cubed = state[i].cube(); + values[col_full_second_mid_sbox(r, i)] = state_cubed; + state[i] *= state_cubed.square(); // Form state ** 7. } - for r in 0..HALF_N_FULL_ROUNDS { - state = constant_layer(state, r); + state = F::mds_layer(&state); - for i in 0..SPONGE_WIDTH { - let state_cubed = state[i] * state[i].square(); - yield_constr - .constraint_wrapping(state_cubed - local_values[col_full_first_mid_sbox(r, i)]); - let state_cubed = local_values[col_full_first_mid_sbox(r, i)]; - state[i] *= state_cubed.square(); // Form state ** 7. - } + for i in 0..SPONGE_WIDTH { + values[col_full_second_after_mds(r, i)] = state[i]; + } + } +} - state = mds_layer(state); +#[inline] +pub(crate) fn eval_permutation_unit( + vars: StarkEvaluationVars, + yield_constr: &mut ConstraintConsumer

, +) where + F: Poseidon, + FE: FieldExtension, + P: PackedField, +{ + let local_values = &vars.local_values; - for i in 0..SPONGE_WIDTH { - yield_constr - .constraint_wrapping(state[i] - local_values[col_full_first_after_mds(r, i)]); - state[i] = local_values[col_full_first_after_mds(r, i)]; - } + // Load inputs. + let mut state = [P::ZEROS; SPONGE_WIDTH]; + for i in 0..SPONGE_WIDTH { + state[i] = local_values[col_input(i)]; + } + + for r in 0..HALF_N_FULL_ROUNDS { + state = constant_layer(state, r); + + for i in 0..SPONGE_WIDTH { + let state_cubed = state[i] * state[i].square(); + yield_constr + .constraint_wrapping(state_cubed - local_values[col_full_first_mid_sbox(r, i)]); + let state_cubed = local_values[col_full_first_mid_sbox(r, i)]; + state[i] *= state_cubed.square(); // Form state ** 7. } - for r in 0..N_PARTIAL_ROUNDS { - state = constant_layer(state, HALF_N_FULL_ROUNDS + r); + state = mds_layer(state); - let state0_cubed = state[0] * state[0].square(); - yield_constr.constraint_wrapping(state0_cubed - local_values[col_partial_mid_sbox(r)]); - let state0_cubed = local_values[col_partial_mid_sbox(r)]; - state[0] *= state0_cubed.square(); // Form state ** 7. - yield_constr.constraint_wrapping(state[0] - local_values[col_partial_after_sbox(r)]); - state[0] = local_values[col_partial_after_sbox(r)]; - - state = mds_layer(state); - } - - for r in 0..HALF_N_FULL_ROUNDS { - state = constant_layer(state, HALF_N_FULL_ROUNDS + N_PARTIAL_ROUNDS + r); - - for i in 0..SPONGE_WIDTH { - let state_cubed = state[i] * state[i].square(); - yield_constr.constraint_wrapping( - state_cubed - local_values[col_full_second_mid_sbox(r, i)], - ); - let state_cubed = local_values[col_full_second_mid_sbox(r, i)]; - state[i] *= state_cubed.square(); // Form state ** 7. - } - - state = mds_layer(state); - - for i in 0..SPONGE_WIDTH { - yield_constr - .constraint_wrapping(state[i] - local_values[col_full_second_after_mds(r, i)]); - state[i] = local_values[col_full_second_after_mds(r, i)]; - } + for i in 0..SPONGE_WIDTH { + yield_constr + .constraint_wrapping(state[i] - local_values[col_full_first_after_mds(r, i)]); + state[i] = local_values[col_full_first_after_mds(r, i)]; } } - pub(crate) fn eval_permutation_unit_recursively( - builder: &mut CircuitBuilder, - vars: StarkEvaluationTargets, - yield_constr: &mut RecursiveConstraintConsumer, - ) { - let zero = builder.zero_extension(); - let local_values = &vars.local_values; + for r in 0..N_PARTIAL_ROUNDS { + state = constant_layer(state, HALF_N_FULL_ROUNDS + r); + + let state0_cubed = state[0] * state[0].square(); + yield_constr.constraint_wrapping(state0_cubed - local_values[col_partial_mid_sbox(r)]); + let state0_cubed = local_values[col_partial_mid_sbox(r)]; + state[0] *= state0_cubed.square(); // Form state ** 7. + yield_constr.constraint_wrapping(state[0] - local_values[col_partial_after_sbox(r)]); + state[0] = local_values[col_partial_after_sbox(r)]; + + state = mds_layer(state); + } + + for r in 0..HALF_N_FULL_ROUNDS { + state = constant_layer(state, HALF_N_FULL_ROUNDS + N_PARTIAL_ROUNDS + r); - // Load inputs. - let mut state = [zero; SPONGE_WIDTH]; for i in 0..SPONGE_WIDTH { - state[i] = local_values[col_input(i)]; + let state_cubed = state[i] * state[i].square(); + yield_constr + .constraint_wrapping(state_cubed - local_values[col_full_second_mid_sbox(r, i)]); + let state_cubed = local_values[col_full_second_mid_sbox(r, i)]; + state[i] *= state_cubed.square(); // Form state ** 7. } - for r in 0..HALF_N_FULL_ROUNDS { - F::constant_layer_recursive(builder, &mut state, r); + state = mds_layer(state); - for i in 0..SPONGE_WIDTH { - let state_cubed = builder.cube_extension(state[i]); - let diff = - builder.sub_extension(state_cubed, local_values[col_full_first_mid_sbox(r, i)]); - yield_constr.constraint_wrapping(builder, diff); - let state_cubed = local_values[col_full_first_mid_sbox(r, i)]; - state[i] = builder.mul_many_extension(&[state[i], state_cubed, state_cubed]); - // Form state ** 7. - } - - state = F::mds_layer_recursive(builder, &state); - - for i in 0..SPONGE_WIDTH { - let diff = - builder.sub_extension(state[i], local_values[col_full_first_after_mds(r, i)]); - yield_constr.constraint_wrapping(builder, diff); - state[i] = local_values[col_full_first_after_mds(r, i)]; - } + for i in 0..SPONGE_WIDTH { + yield_constr + .constraint_wrapping(state[i] - local_values[col_full_second_after_mds(r, i)]); + state[i] = local_values[col_full_second_after_mds(r, i)]; } + } +} - for r in 0..N_PARTIAL_ROUNDS { - F::constant_layer_recursive(builder, &mut state, HALF_N_FULL_ROUNDS + r); +pub(crate) fn eval_permutation_unit_recursively, const D: usize>( + builder: &mut CircuitBuilder, + vars: StarkEvaluationTargets, + yield_constr: &mut RecursiveConstraintConsumer, +) { + let zero = builder.zero_extension(); + let local_values = &vars.local_values; - let state0_cubed = builder.cube_extension(state[0]); - let diff = builder.sub_extension(state0_cubed, local_values[col_partial_mid_sbox(r)]); + // Load inputs. + let mut state = [zero; SPONGE_WIDTH]; + for i in 0..SPONGE_WIDTH { + state[i] = local_values[col_input(i)]; + } + + for r in 0..HALF_N_FULL_ROUNDS { + F::constant_layer_recursive(builder, &mut state, r); + + for i in 0..SPONGE_WIDTH { + let state_cubed = builder.cube_extension(state[i]); + let diff = + builder.sub_extension(state_cubed, local_values[col_full_first_mid_sbox(r, i)]); yield_constr.constraint_wrapping(builder, diff); - let state0_cubed = local_values[col_partial_mid_sbox(r)]; - state[0] = builder.mul_many_extension(&[state[0], state0_cubed, state0_cubed]); // Form state ** 7. - let diff = builder.sub_extension(state[0], local_values[col_partial_after_sbox(r)]); - yield_constr.constraint_wrapping(builder, diff); - state[0] = local_values[col_partial_after_sbox(r)]; - - state = F::mds_layer_recursive(builder, &state); + let state_cubed = local_values[col_full_first_mid_sbox(r, i)]; + state[i] = builder.mul_many_extension(&[state[i], state_cubed, state_cubed]); + // Form state ** 7. } - for r in 0..HALF_N_FULL_ROUNDS { - F::constant_layer_recursive( - builder, - &mut state, - HALF_N_FULL_ROUNDS + N_PARTIAL_ROUNDS + r, - ); + state = F::mds_layer_recursive(builder, &state); - for i in 0..SPONGE_WIDTH { - let state_cubed = builder.cube_extension(state[i]); - let diff = builder - .sub_extension(state_cubed, local_values[col_full_second_mid_sbox(r, i)]); - yield_constr.constraint_wrapping(builder, diff); - let state_cubed = local_values[col_full_second_mid_sbox(r, i)]; - state[i] = builder.mul_many_extension(&[state[i], state_cubed, state_cubed]); - // Form state ** 7. - } + for i in 0..SPONGE_WIDTH { + let diff = + builder.sub_extension(state[i], local_values[col_full_first_after_mds(r, i)]); + yield_constr.constraint_wrapping(builder, diff); + state[i] = local_values[col_full_first_after_mds(r, i)]; + } + } - state = F::mds_layer_recursive(builder, &state); + for r in 0..N_PARTIAL_ROUNDS { + F::constant_layer_recursive(builder, &mut state, HALF_N_FULL_ROUNDS + r); - for i in 0..SPONGE_WIDTH { - let diff = - builder.sub_extension(state[i], local_values[col_full_second_after_mds(r, i)]); - yield_constr.constraint_wrapping(builder, diff); - state[i] = local_values[col_full_second_after_mds(r, i)]; - } + let state0_cubed = builder.cube_extension(state[0]); + let diff = builder.sub_extension(state0_cubed, local_values[col_partial_mid_sbox(r)]); + yield_constr.constraint_wrapping(builder, diff); + let state0_cubed = local_values[col_partial_mid_sbox(r)]; + state[0] = builder.mul_many_extension(&[state[0], state0_cubed, state0_cubed]); // Form state ** 7. + let diff = builder.sub_extension(state[0], local_values[col_partial_after_sbox(r)]); + yield_constr.constraint_wrapping(builder, diff); + state[0] = local_values[col_partial_after_sbox(r)]; + + state = F::mds_layer_recursive(builder, &state); + } + + for r in 0..HALF_N_FULL_ROUNDS { + F::constant_layer_recursive( + builder, + &mut state, + HALF_N_FULL_ROUNDS + N_PARTIAL_ROUNDS + r, + ); + + for i in 0..SPONGE_WIDTH { + let state_cubed = builder.cube_extension(state[i]); + let diff = + builder.sub_extension(state_cubed, local_values[col_full_second_mid_sbox(r, i)]); + yield_constr.constraint_wrapping(builder, diff); + let state_cubed = local_values[col_full_second_mid_sbox(r, i)]; + state[i] = builder.mul_many_extension(&[state[i], state_cubed, state_cubed]); + // Form state ** 7. + } + + state = F::mds_layer_recursive(builder, &state); + + for i in 0..SPONGE_WIDTH { + let diff = + builder.sub_extension(state[i], local_values[col_full_second_after_mds(r, i)]); + yield_constr.constraint_wrapping(builder, diff); + state[i] = local_values[col_full_second_after_mds(r, i)]; } } } @@ -269,11 +266,10 @@ mod tests { use starky::constraint_consumer::ConstraintConsumer; use starky::vars::StarkEvaluationVars; - use crate::permutation_unit::SPONGE_WIDTH; + use crate::permutation_unit::{eval_permutation_unit, generate_permutation_unit, SPONGE_WIDTH}; use crate::public_input_layout::NUM_PUBLIC_INPUTS; use crate::registers::permutation::{col_input, col_output}; use crate::registers::NUM_COLUMNS; - use crate::system_zero::SystemZero; #[test] fn generate_eval_consistency() { @@ -281,7 +277,7 @@ mod tests { type F = GoldilocksField; let mut values = [F::default(); NUM_COLUMNS]; - SystemZero::::generate_permutation_unit(&mut values); + generate_permutation_unit(&mut values); let vars = StarkEvaluationVars { local_values: &values, @@ -295,7 +291,7 @@ mod tests { GoldilocksField::ONE, GoldilocksField::ONE, ); - SystemZero::::eval_permutation_unit(vars, &mut constrant_consumer); + eval_permutation_unit(vars, &mut constrant_consumer); for &acc in &constrant_consumer.constraint_accs { assert_eq!(acc, GoldilocksField::ZERO); } @@ -318,7 +314,7 @@ mod tests { for i in 0..SPONGE_WIDTH { values[col_input(i)] = state[i]; } - SystemZero::::generate_permutation_unit(&mut values); + generate_permutation_unit(&mut values); let mut result = [F::default(); SPONGE_WIDTH]; for i in 0..SPONGE_WIDTH { result[i] = values[col_output(i)]; diff --git a/system_zero/src/system_zero.rs b/system_zero/src/system_zero.rs index 780b1d38..2eeb4697 100644 --- a/system_zero/src/system_zero.rs +++ b/system_zero/src/system_zero.rs @@ -12,7 +12,14 @@ use starky::vars::StarkEvaluationVars; use crate::arithmetic::{ eval_arithmetic_unit, eval_arithmetic_unit_recursively, generate_arithmetic_unit, }; +use crate::core_registers::{ + eval_core_registers, eval_core_registers_recursively, generate_first_row_core_registers, + generate_next_row_core_registers, +}; use crate::memory::TransactionMemory; +use crate::permutation_unit::{ + eval_permutation_unit, eval_permutation_unit_recursively, generate_permutation_unit, +}; use crate::public_input_layout::NUM_PUBLIC_INPUTS; use crate::registers::NUM_COLUMNS; @@ -29,16 +36,17 @@ impl, const D: usize> SystemZero { let memory = TransactionMemory::default(); let mut row = [F::ZERO; NUM_COLUMNS]; - self.generate_first_row_core_registers(&mut row); - Self::generate_permutation_unit(&mut row); + generate_first_row_core_registers(&mut row); + generate_arithmetic_unit(&mut row); + generate_permutation_unit(&mut row); let mut trace = Vec::with_capacity(MIN_TRACE_ROWS); loop { let mut next_row = [F::ZERO; NUM_COLUMNS]; - self.generate_next_row_core_registers(&row, &mut next_row); + generate_next_row_core_registers(&row, &mut next_row); generate_arithmetic_unit(&mut next_row); - Self::generate_permutation_unit(&mut next_row); + generate_permutation_unit(&mut next_row); trace.push(row); row = next_row; @@ -74,9 +82,9 @@ impl, const D: usize> Stark for SystemZero, P: PackedField, { - self.eval_core_registers(vars, yield_constr); + eval_core_registers(vars, yield_constr); eval_arithmetic_unit(vars, yield_constr); - Self::eval_permutation_unit(vars, yield_constr); + eval_permutation_unit::(vars, yield_constr); // TODO: Other units } @@ -86,9 +94,9 @@ impl, const D: usize> Stark for SystemZero, yield_constr: &mut RecursiveConstraintConsumer, ) { - self.eval_core_registers_recursively(builder, vars, yield_constr); + eval_core_registers_recursively(builder, vars, yield_constr); eval_arithmetic_unit_recursively(builder, vars, yield_constr); - Self::eval_permutation_unit_recursively(builder, vars, yield_constr); + eval_permutation_unit_recursively(builder, vars, yield_constr); // TODO: Other units }