Implement Poseidon in system_zero/permutation_unit (#459)

* Implement Poseidon in system_zero/permutation_unit

* Minor cleanup

* Daniel PR comments

* Update dependencies
This commit is contained in:
Jakub Nabaglo 2022-02-04 16:50:57 -08:00 committed by GitHub
parent b6a60e721d
commit 83a572717e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 321 additions and 43 deletions

View File

@ -12,7 +12,9 @@ pub struct ConstraintConsumer<P: PackedField> {
alphas: Vec<P::Scalar>,
/// Running sums of constraints that have been emitted so far, scaled by powers of alpha.
constraint_accs: Vec<P>,
// TODO(JN): This is pub so it can be used in a test. Once we have an API for accessing this
// result, it should be made private.
pub constraint_accs: Vec<P>,
/// The evaluation of `X - g^(n-1)`.
z_last: P,

View File

@ -10,3 +10,5 @@ starky = { path = "../starky" }
anyhow = "1.0.40"
env_logger = "0.9.0"
log = "0.4.14"
rand = "0.8.4"
rand_chacha = "0.3.1"

View File

@ -24,35 +24,56 @@ pub(crate) const COL_STACK_PTR: usize = COL_FRAME_PTR + 1;
const START_PERMUTATION_UNIT: usize = COL_STACK_PTR + 1;
pub(crate) const fn col_permutation_full_first(round: usize, i: usize) -> usize {
const START_PERMUTATION_FULL_FIRST: usize = START_PERMUTATION_UNIT + SPONGE_WIDTH;
pub(crate) const fn col_permutation_full_first_mid_sbox(round: usize, i: usize) -> usize {
debug_assert!(round < poseidon::HALF_N_FULL_ROUNDS);
debug_assert!(i < SPONGE_WIDTH);
START_PERMUTATION_UNIT + round * SPONGE_WIDTH + i
START_PERMUTATION_FULL_FIRST + 2 * round * SPONGE_WIDTH + i
}
pub(crate) const fn col_permutation_full_first_after_mds(round: usize, i: usize) -> usize {
debug_assert!(round < poseidon::HALF_N_FULL_ROUNDS);
debug_assert!(i < SPONGE_WIDTH);
START_PERMUTATION_FULL_FIRST + (2 * round + 1) * SPONGE_WIDTH + i
}
const START_PERMUTATION_PARTIAL: usize =
col_permutation_full_first(poseidon::HALF_N_FULL_ROUNDS - 1, SPONGE_WIDTH - 1) + 1;
col_permutation_full_first_after_mds(poseidon::HALF_N_FULL_ROUNDS - 1, SPONGE_WIDTH - 1) + 1;
pub(crate) const fn col_permutation_partial(round: usize) -> usize {
pub(crate) const fn col_permutation_partial_mid_sbox(round: usize) -> usize {
debug_assert!(round < poseidon::N_PARTIAL_ROUNDS);
START_PERMUTATION_PARTIAL + round
START_PERMUTATION_PARTIAL + 2 * round
}
const START_PERMUTATION_FULL_SECOND: usize = COL_STACK_PTR + 1;
pub(crate) const fn col_permutation_partial_after_sbox(round: usize) -> usize {
debug_assert!(round < poseidon::N_PARTIAL_ROUNDS);
START_PERMUTATION_PARTIAL + 2 * round + 1
}
pub(crate) const fn col_permutation_full_second(round: usize, i: usize) -> usize {
const START_PERMUTATION_FULL_SECOND: usize =
col_permutation_partial_after_sbox(poseidon::N_PARTIAL_ROUNDS - 1) + 1;
pub(crate) const fn col_permutation_full_second_mid_sbox(round: usize, i: usize) -> usize {
debug_assert!(round <= poseidon::HALF_N_FULL_ROUNDS);
debug_assert!(i < SPONGE_WIDTH);
START_PERMUTATION_FULL_SECOND + round * SPONGE_WIDTH + i
START_PERMUTATION_FULL_SECOND + 2 * round * SPONGE_WIDTH + i
}
pub(crate) const fn col_permutation_full_second_after_mds(round: usize, i: usize) -> usize {
debug_assert!(round <= poseidon::HALF_N_FULL_ROUNDS);
debug_assert!(i < SPONGE_WIDTH);
START_PERMUTATION_FULL_SECOND + (2 * round + 1) * SPONGE_WIDTH + i
}
pub(crate) const fn col_permutation_input(i: usize) -> usize {
col_permutation_full_first(0, i)
debug_assert!(i < SPONGE_WIDTH);
START_PERMUTATION_UNIT + i
}
pub(crate) const fn col_permutation_output(i: usize) -> usize {
debug_assert!(i < SPONGE_WIDTH);
col_permutation_full_second(poseidon::HALF_N_FULL_ROUNDS, i)
col_permutation_full_second_after_mds(poseidon::HALF_N_FULL_ROUNDS - 1, i)
}
const END_PERMUTATION_UNIT: usize = col_permutation_output(SPONGE_WIDTH - 1);

View File

@ -2,36 +2,120 @@ use plonky2::field::extension_field::{Extendable, FieldExtension};
use plonky2::field::packed_field::PackedField;
use plonky2::hash::hash_types::RichField;
use plonky2::hash::hashing::SPONGE_WIDTH;
use plonky2::hash::poseidon::{HALF_N_FULL_ROUNDS, N_PARTIAL_ROUNDS};
use plonky2::plonk::circuit_builder::CircuitBuilder;
use starky::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer};
use starky::vars::StarkEvaluationTargets;
use starky::vars::StarkEvaluationVars;
use crate::column_layout::{col_permutation_input, col_permutation_output, NUM_COLUMNS};
use crate::column_layout::{
col_permutation_full_first_after_mds as col_full_1st_after_mds,
col_permutation_full_first_mid_sbox as col_full_1st_mid_sbox,
col_permutation_full_second_after_mds as col_full_2nd_after_mds,
col_permutation_full_second_mid_sbox as col_full_2nd_mid_sbox,
col_permutation_input as col_input,
col_permutation_partial_after_sbox as col_partial_after_sbox,
col_permutation_partial_mid_sbox as col_partial_mid_sbox, NUM_COLUMNS,
};
use crate::public_input_layout::NUM_PUBLIC_INPUTS;
use crate::system_zero::SystemZero;
fn constant_layer<F, FE, P, const D2: usize>(
mut state: [P; SPONGE_WIDTH],
round: usize,
) -> [P; SPONGE_WIDTH]
where
F: RichField,
FE: FieldExtension<D2, BaseField = F>,
P: PackedField<Scalar = FE>,
{
// One day I might actually vectorize this, but today is not that day.
for i in 0..P::WIDTH {
let mut unpacked_state = [P::Scalar::default(); SPONGE_WIDTH];
for j in 0..SPONGE_WIDTH {
unpacked_state[j] = state[j].as_slice()[i];
}
F::constant_layer_field(&mut unpacked_state, round);
for j in 0..SPONGE_WIDTH {
state[j].as_slice_mut()[i] = unpacked_state[j];
}
}
state
}
fn mds_layer<F, FE, P, const D2: usize>(mut state: [P; SPONGE_WIDTH]) -> [P; SPONGE_WIDTH]
where
F: RichField,
FE: FieldExtension<D2, BaseField = F>,
P: PackedField<Scalar = FE>,
{
for i in 0..P::WIDTH {
let mut unpacked_state = [P::Scalar::default(); SPONGE_WIDTH];
for j in 0..SPONGE_WIDTH {
unpacked_state[j] = state[j].as_slice()[i];
}
unpacked_state = F::mds_layer_field(&unpacked_state);
for j in 0..SPONGE_WIDTH {
state[j].as_slice_mut()[i] = unpacked_state[j];
}
}
state
}
impl<F: RichField + Extendable<D>, const D: usize> SystemZero<F, D> {
pub(crate) fn generate_permutation_unit(&self, values: &mut [F; NUM_COLUMNS]) {
pub(crate) fn generate_permutation_unit(values: &mut [F; NUM_COLUMNS]) {
// Load inputs.
let mut state = [F::ZERO; SPONGE_WIDTH];
for i in 0..SPONGE_WIDTH {
state[i] = values[col_permutation_input(i)];
state[i] = values[col_input(i)];
}
// TODO: First full rounds.
// TODO: Partial rounds.
// TODO: Second full rounds.
for r in 0..HALF_N_FULL_ROUNDS {
F::constant_layer(&mut state, r);
// Write outputs.
for i in 0..SPONGE_WIDTH {
values[col_permutation_output(i)] = state[i];
for i in 0..SPONGE_WIDTH {
let state_cubed = state[i].cube();
values[col_full_1st_mid_sbox(r, i)] = state_cubed;
state[i] *= state_cubed.square(); // Form state ** 7.
}
state = F::mds_layer(&state);
for i in 0..SPONGE_WIDTH {
values[col_full_1st_after_mds(r, i)] = state[i];
}
}
for r in 0..N_PARTIAL_ROUNDS {
F::constant_layer(&mut state, HALF_N_FULL_ROUNDS + r);
let state0_cubed = state[0].cube();
values[col_partial_mid_sbox(r)] = state0_cubed;
state[0] *= state0_cubed.square(); // Form state ** 7.
values[col_partial_after_sbox(r)] = state[0];
state = F::mds_layer(&state);
}
for r in 0..HALF_N_FULL_ROUNDS {
F::constant_layer(&mut state, HALF_N_FULL_ROUNDS + N_PARTIAL_ROUNDS + r);
for i in 0..SPONGE_WIDTH {
let state_cubed = state[i].cube();
values[col_full_2nd_mid_sbox(r, i)] = state_cubed;
state[i] *= state_cubed.square(); // Form state ** 7.
}
state = F::mds_layer(&state);
for i in 0..SPONGE_WIDTH {
values[col_full_2nd_after_mds(r, i)] = state[i];
}
}
}
#[inline]
pub(crate) fn eval_permutation_unit<FE, P, const D2: usize>(
&self,
vars: StarkEvaluationVars<FE, P, NUM_COLUMNS, NUM_PUBLIC_INPUTS>,
yield_constr: &mut ConstraintConsumer<P>,
) where
@ -43,22 +127,64 @@ impl<F: RichField + Extendable<D>, const D: usize> SystemZero<F, D> {
// Load inputs.
let mut state = [P::ZEROS; SPONGE_WIDTH];
for i in 0..SPONGE_WIDTH {
state[i] = local_values[col_permutation_input(i)];
state[i] = local_values[col_input(i)];
}
// TODO: First full rounds.
// TODO: Partial rounds.
// TODO: Second full rounds.
for r in 0..HALF_N_FULL_ROUNDS {
state = constant_layer(state, r);
// Assert that the computed output matches the outputs in the trace.
for i in 0..SPONGE_WIDTH {
let out = local_values[col_permutation_output(i)];
yield_constr.constraint(state[i] - out);
for i in 0..SPONGE_WIDTH {
let state_cubed = state[i] * state[i].square();
yield_constr
.constraint_wrapping(state_cubed - local_values[col_full_1st_mid_sbox(r, i)]);
let state_cubed = local_values[col_full_1st_mid_sbox(r, i)];
state[i] *= state_cubed.square(); // Form state ** 7.
}
state = mds_layer(state);
for i in 0..SPONGE_WIDTH {
yield_constr
.constraint_wrapping(state[i] - local_values[col_full_1st_after_mds(r, i)]);
state[i] = local_values[col_full_1st_after_mds(r, i)];
}
}
for r in 0..N_PARTIAL_ROUNDS {
state = constant_layer(state, HALF_N_FULL_ROUNDS + r);
let state0_cubed = state[0] * state[0].square();
yield_constr.constraint_wrapping(state0_cubed - local_values[col_partial_mid_sbox(r)]);
let state0_cubed = local_values[col_partial_mid_sbox(r)];
state[0] *= state0_cubed.square(); // Form state ** 7.
yield_constr.constraint_wrapping(state[0] - local_values[col_partial_after_sbox(r)]);
state[0] = local_values[col_partial_after_sbox(r)];
state = mds_layer(state);
}
for r in 0..HALF_N_FULL_ROUNDS {
state = constant_layer(state, HALF_N_FULL_ROUNDS + N_PARTIAL_ROUNDS + r);
for i in 0..SPONGE_WIDTH {
let state_cubed = state[i] * state[i].square();
yield_constr
.constraint_wrapping(state_cubed - local_values[col_full_2nd_mid_sbox(r, i)]);
let state_cubed = local_values[col_full_2nd_mid_sbox(r, i)];
state[i] *= state_cubed.square(); // Form state ** 7.
}
state = mds_layer(state);
for i in 0..SPONGE_WIDTH {
yield_constr
.constraint_wrapping(state[i] - local_values[col_full_2nd_after_mds(r, i)]);
state[i] = local_values[col_full_2nd_after_mds(r, i)];
}
}
}
pub(crate) fn eval_permutation_unit_recursively(
&self,
builder: &mut CircuitBuilder<F, D>,
vars: StarkEvaluationTargets<D, NUM_COLUMNS, NUM_PUBLIC_INPUTS>,
yield_constr: &mut RecursiveConstraintConsumer<F, D>,
@ -69,18 +195,145 @@ impl<F: RichField + Extendable<D>, const D: usize> SystemZero<F, D> {
// Load inputs.
let mut state = [zero; SPONGE_WIDTH];
for i in 0..SPONGE_WIDTH {
state[i] = local_values[col_permutation_input(i)];
state[i] = local_values[col_input(i)];
}
// TODO: First full rounds.
// TODO: Partial rounds.
// TODO: Second full rounds.
for r in 0..HALF_N_FULL_ROUNDS {
F::constant_layer_recursive(builder, &mut state, r);
// Assert that the computed output matches the outputs in the trace.
for i in 0..SPONGE_WIDTH {
let out = local_values[col_permutation_output(i)];
let diff = builder.sub_extension(state[i], out);
yield_constr.constraint(builder, diff);
for i in 0..SPONGE_WIDTH {
let state_cubed = builder.cube_extension(state[i]);
let diff =
builder.sub_extension(state_cubed, local_values[col_full_1st_mid_sbox(r, i)]);
yield_constr.constraint_wrapping(builder, diff);
let state_cubed = local_values[col_full_1st_mid_sbox(r, i)];
state[i] = builder.mul_many_extension(&[state[i], state_cubed, state_cubed]);
// Form state ** 7.
}
state = F::mds_layer_recursive(builder, &state);
for i in 0..SPONGE_WIDTH {
let diff =
builder.sub_extension(state[i], local_values[col_full_1st_after_mds(r, i)]);
yield_constr.constraint_wrapping(builder, diff);
state[i] = local_values[col_full_1st_after_mds(r, i)];
}
}
for r in 0..N_PARTIAL_ROUNDS {
F::constant_layer_recursive(builder, &mut state, HALF_N_FULL_ROUNDS + r);
let state0_cubed = builder.cube_extension(state[0]);
let diff = builder.sub_extension(state0_cubed, local_values[col_partial_mid_sbox(r)]);
yield_constr.constraint_wrapping(builder, diff);
let state0_cubed = local_values[col_partial_mid_sbox(r)];
state[0] = builder.mul_many_extension(&[state[0], state0_cubed, state0_cubed]); // Form state ** 7.
let diff = builder.sub_extension(state[0], local_values[col_partial_after_sbox(r)]);
yield_constr.constraint_wrapping(builder, diff);
state[0] = local_values[col_partial_after_sbox(r)];
state = F::mds_layer_recursive(builder, &state);
}
for r in 0..HALF_N_FULL_ROUNDS {
F::constant_layer_recursive(
builder,
&mut state,
HALF_N_FULL_ROUNDS + N_PARTIAL_ROUNDS + r,
);
for i in 0..SPONGE_WIDTH {
let state_cubed = builder.cube_extension(state[i]);
let diff =
builder.sub_extension(state_cubed, local_values[col_full_2nd_mid_sbox(r, i)]);
yield_constr.constraint_wrapping(builder, diff);
let state_cubed = local_values[col_full_2nd_mid_sbox(r, i)];
state[i] = builder.mul_many_extension(&[state[i], state_cubed, state_cubed]);
// Form state ** 7.
}
state = F::mds_layer_recursive(builder, &state);
for i in 0..SPONGE_WIDTH {
let diff =
builder.sub_extension(state[i], local_values[col_full_2nd_after_mds(r, i)]);
yield_constr.constraint_wrapping(builder, diff);
state[i] = local_values[col_full_2nd_after_mds(r, i)];
}
}
}
}
#[cfg(test)]
mod tests {
use plonky2::field::field_types::Field;
use plonky2::field::goldilocks_field::GoldilocksField;
use plonky2::hash::poseidon::Poseidon;
use rand::SeedableRng;
use rand_chacha::ChaCha8Rng;
use starky::constraint_consumer::ConstraintConsumer;
use starky::vars::StarkEvaluationVars;
use crate::column_layout::{
col_permutation_input as col_input, col_permutation_output as col_output, NUM_COLUMNS,
};
use crate::permutation_unit::SPONGE_WIDTH;
use crate::public_input_layout::NUM_PUBLIC_INPUTS;
use crate::system_zero::SystemZero;
#[test]
fn generate_eval_consistency() {
const D: usize = 1;
type F = GoldilocksField;
let mut values = [F::default(); NUM_COLUMNS];
SystemZero::<F, D>::generate_permutation_unit(&mut values);
let vars = StarkEvaluationVars {
local_values: &values,
next_values: &[F::default(); NUM_COLUMNS],
public_inputs: &[F::default(); NUM_PUBLIC_INPUTS],
};
let mut constrant_consumer = ConstraintConsumer::new(
vec![GoldilocksField(2), GoldilocksField(3), GoldilocksField(5)],
GoldilocksField::ONE,
GoldilocksField::ONE,
GoldilocksField::ONE,
);
SystemZero::<F, D>::eval_permutation_unit(vars, &mut constrant_consumer);
for &acc in &constrant_consumer.constraint_accs {
assert_eq!(acc, GoldilocksField::ZERO);
}
}
#[test]
fn poseidon_result() {
const D: usize = 1;
type F = GoldilocksField;
let mut rng = ChaCha8Rng::seed_from_u64(0x6feb51b7ec230f25);
let state = [F::default(); SPONGE_WIDTH].map(|_| F::rand_from_rng(&mut rng));
// Get true Poseidon hash
let target = GoldilocksField::poseidon(state);
// Get result from `generate_permutation_unit`
// Initialize `values` with randomness to test that the code doesn't rely on zero-filling.
let mut values = [F::default(); NUM_COLUMNS].map(|_| F::rand_from_rng(&mut rng));
for i in 0..SPONGE_WIDTH {
values[col_input(i)] = state[i];
}
SystemZero::<F, D>::generate_permutation_unit(&mut values);
let mut result = [F::default(); SPONGE_WIDTH];
for i in 0..SPONGE_WIDTH {
result[i] = values[col_output(i)];
}
assert_eq!(target, result);
}
// TODO(JN): test degree
// TODO(JN): test `eval_permutation_unit_recursively`
}

View File

@ -27,14 +27,14 @@ impl<F: RichField + Extendable<D>, const D: usize> SystemZero<F, D> {
let mut row = [F::ZERO; NUM_COLUMNS];
self.generate_first_row_core_registers(&mut row);
self.generate_permutation_unit(&mut row);
Self::generate_permutation_unit(&mut row);
let mut trace = Vec::with_capacity(MIN_TRACE_ROWS);
loop {
let mut next_row = [F::ZERO; NUM_COLUMNS];
self.generate_next_row_core_registers(&row, &mut next_row);
self.generate_permutation_unit(&mut next_row);
Self::generate_permutation_unit(&mut next_row);
trace.push(row);
row = next_row;
@ -66,7 +66,7 @@ impl<F: RichField + Extendable<D>, const D: usize> Stark<F, D> for SystemZero<F,
P: PackedField<Scalar = FE>,
{
self.eval_core_registers(vars, yield_constr);
self.eval_permutation_unit(vars, yield_constr);
Self::eval_permutation_unit(vars, yield_constr);
todo!()
}
@ -77,7 +77,7 @@ impl<F: RichField + Extendable<D>, const D: usize> Stark<F, D> for SystemZero<F,
yield_constr: &mut RecursiveConstraintConsumer<F, D>,
) {
self.eval_core_registers_recursively(builder, vars, yield_constr);
self.eval_permutation_unit_recursively(builder, vars, yield_constr);
Self::eval_permutation_unit_recursively(builder, vars, yield_constr);
todo!()
}