diff --git a/starky/src/constraint_consumer.rs b/starky/src/constraint_consumer.rs index b7c9f399..c909b520 100644 --- a/starky/src/constraint_consumer.rs +++ b/starky/src/constraint_consumer.rs @@ -12,7 +12,9 @@ pub struct ConstraintConsumer { alphas: Vec, /// Running sums of constraints that have been emitted so far, scaled by powers of alpha. - constraint_accs: Vec

, + // TODO(JN): This is pub so it can be used in a test. Once we have an API for accessing this + // result, it should be made private. + pub constraint_accs: Vec

, /// The evaluation of `X - g^(n-1)`. z_last: P, diff --git a/system_zero/Cargo.toml b/system_zero/Cargo.toml index b908dea0..e5b617c9 100644 --- a/system_zero/Cargo.toml +++ b/system_zero/Cargo.toml @@ -10,3 +10,5 @@ starky = { path = "../starky" } anyhow = "1.0.40" env_logger = "0.9.0" log = "0.4.14" +rand = "0.8.4" +rand_chacha = "0.3.1" diff --git a/system_zero/src/column_layout.rs b/system_zero/src/column_layout.rs index 3d8fc2c0..7a9e92e5 100644 --- a/system_zero/src/column_layout.rs +++ b/system_zero/src/column_layout.rs @@ -24,35 +24,56 @@ pub(crate) const COL_STACK_PTR: usize = COL_FRAME_PTR + 1; const START_PERMUTATION_UNIT: usize = COL_STACK_PTR + 1; -pub(crate) const fn col_permutation_full_first(round: usize, i: usize) -> usize { +const START_PERMUTATION_FULL_FIRST: usize = START_PERMUTATION_UNIT + SPONGE_WIDTH; + +pub(crate) const fn col_permutation_full_first_mid_sbox(round: usize, i: usize) -> usize { debug_assert!(round < poseidon::HALF_N_FULL_ROUNDS); debug_assert!(i < SPONGE_WIDTH); - START_PERMUTATION_UNIT + round * SPONGE_WIDTH + i + START_PERMUTATION_FULL_FIRST + 2 * round * SPONGE_WIDTH + i +} + +pub(crate) const fn col_permutation_full_first_after_mds(round: usize, i: usize) -> usize { + debug_assert!(round < poseidon::HALF_N_FULL_ROUNDS); + debug_assert!(i < SPONGE_WIDTH); + START_PERMUTATION_FULL_FIRST + (2 * round + 1) * SPONGE_WIDTH + i } const START_PERMUTATION_PARTIAL: usize = - col_permutation_full_first(poseidon::HALF_N_FULL_ROUNDS - 1, SPONGE_WIDTH - 1) + 1; + col_permutation_full_first_after_mds(poseidon::HALF_N_FULL_ROUNDS - 1, SPONGE_WIDTH - 1) + 1; -pub(crate) const fn col_permutation_partial(round: usize) -> usize { +pub(crate) const fn col_permutation_partial_mid_sbox(round: usize) -> usize { debug_assert!(round < poseidon::N_PARTIAL_ROUNDS); - START_PERMUTATION_PARTIAL + round + START_PERMUTATION_PARTIAL + 2 * round } -const START_PERMUTATION_FULL_SECOND: usize = COL_STACK_PTR + 1; +pub(crate) const fn col_permutation_partial_after_sbox(round: usize) -> usize { + debug_assert!(round < poseidon::N_PARTIAL_ROUNDS); + START_PERMUTATION_PARTIAL + 2 * round + 1 +} -pub(crate) const fn col_permutation_full_second(round: usize, i: usize) -> usize { +const START_PERMUTATION_FULL_SECOND: usize = + col_permutation_partial_after_sbox(poseidon::N_PARTIAL_ROUNDS - 1) + 1; + +pub(crate) const fn col_permutation_full_second_mid_sbox(round: usize, i: usize) -> usize { debug_assert!(round <= poseidon::HALF_N_FULL_ROUNDS); debug_assert!(i < SPONGE_WIDTH); - START_PERMUTATION_FULL_SECOND + round * SPONGE_WIDTH + i + START_PERMUTATION_FULL_SECOND + 2 * round * SPONGE_WIDTH + i +} + +pub(crate) const fn col_permutation_full_second_after_mds(round: usize, i: usize) -> usize { + debug_assert!(round <= poseidon::HALF_N_FULL_ROUNDS); + debug_assert!(i < SPONGE_WIDTH); + START_PERMUTATION_FULL_SECOND + (2 * round + 1) * SPONGE_WIDTH + i } pub(crate) const fn col_permutation_input(i: usize) -> usize { - col_permutation_full_first(0, i) + debug_assert!(i < SPONGE_WIDTH); + START_PERMUTATION_UNIT + i } pub(crate) const fn col_permutation_output(i: usize) -> usize { debug_assert!(i < SPONGE_WIDTH); - col_permutation_full_second(poseidon::HALF_N_FULL_ROUNDS, i) + col_permutation_full_second_after_mds(poseidon::HALF_N_FULL_ROUNDS - 1, i) } const END_PERMUTATION_UNIT: usize = col_permutation_output(SPONGE_WIDTH - 1); diff --git a/system_zero/src/permutation_unit.rs b/system_zero/src/permutation_unit.rs index 43883fca..7f12b9ce 100644 --- a/system_zero/src/permutation_unit.rs +++ b/system_zero/src/permutation_unit.rs @@ -2,36 +2,120 @@ use plonky2::field::extension_field::{Extendable, FieldExtension}; use plonky2::field::packed_field::PackedField; use plonky2::hash::hash_types::RichField; use plonky2::hash::hashing::SPONGE_WIDTH; +use plonky2::hash::poseidon::{HALF_N_FULL_ROUNDS, N_PARTIAL_ROUNDS}; use plonky2::plonk::circuit_builder::CircuitBuilder; use starky::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; use starky::vars::StarkEvaluationTargets; use starky::vars::StarkEvaluationVars; -use crate::column_layout::{col_permutation_input, col_permutation_output, NUM_COLUMNS}; +use crate::column_layout::{ + col_permutation_full_first_after_mds as col_full_1st_after_mds, + col_permutation_full_first_mid_sbox as col_full_1st_mid_sbox, + col_permutation_full_second_after_mds as col_full_2nd_after_mds, + col_permutation_full_second_mid_sbox as col_full_2nd_mid_sbox, + col_permutation_input as col_input, + col_permutation_partial_after_sbox as col_partial_after_sbox, + col_permutation_partial_mid_sbox as col_partial_mid_sbox, NUM_COLUMNS, +}; use crate::public_input_layout::NUM_PUBLIC_INPUTS; use crate::system_zero::SystemZero; +fn constant_layer( + mut state: [P; SPONGE_WIDTH], + round: usize, +) -> [P; SPONGE_WIDTH] +where + F: RichField, + FE: FieldExtension, + P: PackedField, +{ + // One day I might actually vectorize this, but today is not that day. + for i in 0..P::WIDTH { + let mut unpacked_state = [P::Scalar::default(); SPONGE_WIDTH]; + for j in 0..SPONGE_WIDTH { + unpacked_state[j] = state[j].as_slice()[i]; + } + F::constant_layer_field(&mut unpacked_state, round); + for j in 0..SPONGE_WIDTH { + state[j].as_slice_mut()[i] = unpacked_state[j]; + } + } + state +} + +fn mds_layer(mut state: [P; SPONGE_WIDTH]) -> [P; SPONGE_WIDTH] +where + F: RichField, + FE: FieldExtension, + P: PackedField, +{ + for i in 0..P::WIDTH { + let mut unpacked_state = [P::Scalar::default(); SPONGE_WIDTH]; + for j in 0..SPONGE_WIDTH { + unpacked_state[j] = state[j].as_slice()[i]; + } + unpacked_state = F::mds_layer_field(&unpacked_state); + for j in 0..SPONGE_WIDTH { + state[j].as_slice_mut()[i] = unpacked_state[j]; + } + } + state +} + impl, const D: usize> SystemZero { - pub(crate) fn generate_permutation_unit(&self, values: &mut [F; NUM_COLUMNS]) { + pub(crate) fn generate_permutation_unit(values: &mut [F; NUM_COLUMNS]) { // Load inputs. let mut state = [F::ZERO; SPONGE_WIDTH]; for i in 0..SPONGE_WIDTH { - state[i] = values[col_permutation_input(i)]; + state[i] = values[col_input(i)]; } - // TODO: First full rounds. - // TODO: Partial rounds. - // TODO: Second full rounds. + for r in 0..HALF_N_FULL_ROUNDS { + F::constant_layer(&mut state, r); - // Write outputs. - for i in 0..SPONGE_WIDTH { - values[col_permutation_output(i)] = state[i]; + for i in 0..SPONGE_WIDTH { + let state_cubed = state[i].cube(); + values[col_full_1st_mid_sbox(r, i)] = state_cubed; + state[i] *= state_cubed.square(); // Form state ** 7. + } + + state = F::mds_layer(&state); + + for i in 0..SPONGE_WIDTH { + values[col_full_1st_after_mds(r, i)] = state[i]; + } + } + + for r in 0..N_PARTIAL_ROUNDS { + F::constant_layer(&mut state, HALF_N_FULL_ROUNDS + r); + + let state0_cubed = state[0].cube(); + values[col_partial_mid_sbox(r)] = state0_cubed; + state[0] *= state0_cubed.square(); // Form state ** 7. + values[col_partial_after_sbox(r)] = state[0]; + + state = F::mds_layer(&state); + } + + for r in 0..HALF_N_FULL_ROUNDS { + F::constant_layer(&mut state, HALF_N_FULL_ROUNDS + N_PARTIAL_ROUNDS + r); + + for i in 0..SPONGE_WIDTH { + let state_cubed = state[i].cube(); + values[col_full_2nd_mid_sbox(r, i)] = state_cubed; + state[i] *= state_cubed.square(); // Form state ** 7. + } + + state = F::mds_layer(&state); + + for i in 0..SPONGE_WIDTH { + values[col_full_2nd_after_mds(r, i)] = state[i]; + } } } #[inline] pub(crate) fn eval_permutation_unit( - &self, vars: StarkEvaluationVars, yield_constr: &mut ConstraintConsumer

, ) where @@ -43,22 +127,64 @@ impl, const D: usize> SystemZero { // Load inputs. let mut state = [P::ZEROS; SPONGE_WIDTH]; for i in 0..SPONGE_WIDTH { - state[i] = local_values[col_permutation_input(i)]; + state[i] = local_values[col_input(i)]; } - // TODO: First full rounds. - // TODO: Partial rounds. - // TODO: Second full rounds. + for r in 0..HALF_N_FULL_ROUNDS { + state = constant_layer(state, r); - // Assert that the computed output matches the outputs in the trace. - for i in 0..SPONGE_WIDTH { - let out = local_values[col_permutation_output(i)]; - yield_constr.constraint(state[i] - out); + for i in 0..SPONGE_WIDTH { + let state_cubed = state[i] * state[i].square(); + yield_constr + .constraint_wrapping(state_cubed - local_values[col_full_1st_mid_sbox(r, i)]); + let state_cubed = local_values[col_full_1st_mid_sbox(r, i)]; + state[i] *= state_cubed.square(); // Form state ** 7. + } + + state = mds_layer(state); + + for i in 0..SPONGE_WIDTH { + yield_constr + .constraint_wrapping(state[i] - local_values[col_full_1st_after_mds(r, i)]); + state[i] = local_values[col_full_1st_after_mds(r, i)]; + } + } + + for r in 0..N_PARTIAL_ROUNDS { + state = constant_layer(state, HALF_N_FULL_ROUNDS + r); + + let state0_cubed = state[0] * state[0].square(); + yield_constr.constraint_wrapping(state0_cubed - local_values[col_partial_mid_sbox(r)]); + let state0_cubed = local_values[col_partial_mid_sbox(r)]; + state[0] *= state0_cubed.square(); // Form state ** 7. + yield_constr.constraint_wrapping(state[0] - local_values[col_partial_after_sbox(r)]); + state[0] = local_values[col_partial_after_sbox(r)]; + + state = mds_layer(state); + } + + for r in 0..HALF_N_FULL_ROUNDS { + state = constant_layer(state, HALF_N_FULL_ROUNDS + N_PARTIAL_ROUNDS + r); + + for i in 0..SPONGE_WIDTH { + let state_cubed = state[i] * state[i].square(); + yield_constr + .constraint_wrapping(state_cubed - local_values[col_full_2nd_mid_sbox(r, i)]); + let state_cubed = local_values[col_full_2nd_mid_sbox(r, i)]; + state[i] *= state_cubed.square(); // Form state ** 7. + } + + state = mds_layer(state); + + for i in 0..SPONGE_WIDTH { + yield_constr + .constraint_wrapping(state[i] - local_values[col_full_2nd_after_mds(r, i)]); + state[i] = local_values[col_full_2nd_after_mds(r, i)]; + } } } pub(crate) fn eval_permutation_unit_recursively( - &self, builder: &mut CircuitBuilder, vars: StarkEvaluationTargets, yield_constr: &mut RecursiveConstraintConsumer, @@ -69,18 +195,145 @@ impl, const D: usize> SystemZero { // Load inputs. let mut state = [zero; SPONGE_WIDTH]; for i in 0..SPONGE_WIDTH { - state[i] = local_values[col_permutation_input(i)]; + state[i] = local_values[col_input(i)]; } - // TODO: First full rounds. - // TODO: Partial rounds. - // TODO: Second full rounds. + for r in 0..HALF_N_FULL_ROUNDS { + F::constant_layer_recursive(builder, &mut state, r); - // Assert that the computed output matches the outputs in the trace. - for i in 0..SPONGE_WIDTH { - let out = local_values[col_permutation_output(i)]; - let diff = builder.sub_extension(state[i], out); - yield_constr.constraint(builder, diff); + for i in 0..SPONGE_WIDTH { + let state_cubed = builder.cube_extension(state[i]); + let diff = + builder.sub_extension(state_cubed, local_values[col_full_1st_mid_sbox(r, i)]); + yield_constr.constraint_wrapping(builder, diff); + let state_cubed = local_values[col_full_1st_mid_sbox(r, i)]; + state[i] = builder.mul_many_extension(&[state[i], state_cubed, state_cubed]); + // Form state ** 7. + } + + state = F::mds_layer_recursive(builder, &state); + + for i in 0..SPONGE_WIDTH { + let diff = + builder.sub_extension(state[i], local_values[col_full_1st_after_mds(r, i)]); + yield_constr.constraint_wrapping(builder, diff); + state[i] = local_values[col_full_1st_after_mds(r, i)]; + } + } + + for r in 0..N_PARTIAL_ROUNDS { + F::constant_layer_recursive(builder, &mut state, HALF_N_FULL_ROUNDS + r); + + let state0_cubed = builder.cube_extension(state[0]); + let diff = builder.sub_extension(state0_cubed, local_values[col_partial_mid_sbox(r)]); + yield_constr.constraint_wrapping(builder, diff); + let state0_cubed = local_values[col_partial_mid_sbox(r)]; + state[0] = builder.mul_many_extension(&[state[0], state0_cubed, state0_cubed]); // Form state ** 7. + let diff = builder.sub_extension(state[0], local_values[col_partial_after_sbox(r)]); + yield_constr.constraint_wrapping(builder, diff); + state[0] = local_values[col_partial_after_sbox(r)]; + + state = F::mds_layer_recursive(builder, &state); + } + + for r in 0..HALF_N_FULL_ROUNDS { + F::constant_layer_recursive( + builder, + &mut state, + HALF_N_FULL_ROUNDS + N_PARTIAL_ROUNDS + r, + ); + + for i in 0..SPONGE_WIDTH { + let state_cubed = builder.cube_extension(state[i]); + let diff = + builder.sub_extension(state_cubed, local_values[col_full_2nd_mid_sbox(r, i)]); + yield_constr.constraint_wrapping(builder, diff); + let state_cubed = local_values[col_full_2nd_mid_sbox(r, i)]; + state[i] = builder.mul_many_extension(&[state[i], state_cubed, state_cubed]); + // Form state ** 7. + } + + state = F::mds_layer_recursive(builder, &state); + + for i in 0..SPONGE_WIDTH { + let diff = + builder.sub_extension(state[i], local_values[col_full_2nd_after_mds(r, i)]); + yield_constr.constraint_wrapping(builder, diff); + state[i] = local_values[col_full_2nd_after_mds(r, i)]; + } } } } + +#[cfg(test)] +mod tests { + use plonky2::field::field_types::Field; + use plonky2::field::goldilocks_field::GoldilocksField; + use plonky2::hash::poseidon::Poseidon; + use rand::SeedableRng; + use rand_chacha::ChaCha8Rng; + use starky::constraint_consumer::ConstraintConsumer; + use starky::vars::StarkEvaluationVars; + + use crate::column_layout::{ + col_permutation_input as col_input, col_permutation_output as col_output, NUM_COLUMNS, + }; + use crate::permutation_unit::SPONGE_WIDTH; + use crate::public_input_layout::NUM_PUBLIC_INPUTS; + use crate::system_zero::SystemZero; + + #[test] + fn generate_eval_consistency() { + const D: usize = 1; + type F = GoldilocksField; + + let mut values = [F::default(); NUM_COLUMNS]; + SystemZero::::generate_permutation_unit(&mut values); + + let vars = StarkEvaluationVars { + local_values: &values, + next_values: &[F::default(); NUM_COLUMNS], + public_inputs: &[F::default(); NUM_PUBLIC_INPUTS], + }; + + let mut constrant_consumer = ConstraintConsumer::new( + vec![GoldilocksField(2), GoldilocksField(3), GoldilocksField(5)], + GoldilocksField::ONE, + GoldilocksField::ONE, + GoldilocksField::ONE, + ); + SystemZero::::eval_permutation_unit(vars, &mut constrant_consumer); + for &acc in &constrant_consumer.constraint_accs { + assert_eq!(acc, GoldilocksField::ZERO); + } + } + + #[test] + fn poseidon_result() { + const D: usize = 1; + type F = GoldilocksField; + + let mut rng = ChaCha8Rng::seed_from_u64(0x6feb51b7ec230f25); + let state = [F::default(); SPONGE_WIDTH].map(|_| F::rand_from_rng(&mut rng)); + + // Get true Poseidon hash + let target = GoldilocksField::poseidon(state); + + // Get result from `generate_permutation_unit` + // Initialize `values` with randomness to test that the code doesn't rely on zero-filling. + let mut values = [F::default(); NUM_COLUMNS].map(|_| F::rand_from_rng(&mut rng)); + for i in 0..SPONGE_WIDTH { + values[col_input(i)] = state[i]; + } + SystemZero::::generate_permutation_unit(&mut values); + let mut result = [F::default(); SPONGE_WIDTH]; + for i in 0..SPONGE_WIDTH { + result[i] = values[col_output(i)]; + } + + assert_eq!(target, result); + } + + // TODO(JN): test degree + // TODO(JN): test `eval_permutation_unit_recursively` +} diff --git a/system_zero/src/system_zero.rs b/system_zero/src/system_zero.rs index 9d78939c..70d3bbca 100644 --- a/system_zero/src/system_zero.rs +++ b/system_zero/src/system_zero.rs @@ -27,14 +27,14 @@ impl, const D: usize> SystemZero { let mut row = [F::ZERO; NUM_COLUMNS]; self.generate_first_row_core_registers(&mut row); - self.generate_permutation_unit(&mut row); + Self::generate_permutation_unit(&mut row); let mut trace = Vec::with_capacity(MIN_TRACE_ROWS); loop { let mut next_row = [F::ZERO; NUM_COLUMNS]; self.generate_next_row_core_registers(&row, &mut next_row); - self.generate_permutation_unit(&mut next_row); + Self::generate_permutation_unit(&mut next_row); trace.push(row); row = next_row; @@ -66,7 +66,7 @@ impl, const D: usize> Stark for SystemZero, { self.eval_core_registers(vars, yield_constr); - self.eval_permutation_unit(vars, yield_constr); + Self::eval_permutation_unit(vars, yield_constr); todo!() } @@ -77,7 +77,7 @@ impl, const D: usize> Stark for SystemZero, ) { self.eval_core_registers_recursively(builder, vars, yield_constr); - self.eval_permutation_unit_recursively(builder, vars, yield_constr); + Self::eval_permutation_unit_recursively(builder, vars, yield_constr); todo!() }