mirror of
https://github.com/logos-storage/plonky2.git
synced 2026-01-10 01:33:07 +00:00
Working recursively
This commit is contained in:
parent
c508fe4362
commit
5d7f4de2a6
@ -131,7 +131,6 @@ where
|
||||
|
||||
<F as Poseidon<WIDTH>>::partial_first_constant_layer(&mut state);
|
||||
state = <F as Poseidon<WIDTH>>::mds_partial_layer_init(&mut state);
|
||||
// for r in 0..(poseidon::N_PARTIAL_ROUNDS - 1) {
|
||||
for r in 0..(poseidon::N_PARTIAL_ROUNDS - 1) {
|
||||
let sbox_in = vars.local_wires[Self::wire_partial_sbox(r)];
|
||||
constraints.push(state[0] - sbox_in);
|
||||
@ -170,7 +169,78 @@ where
|
||||
}
|
||||
|
||||
fn eval_unfiltered_base(&self, vars: EvaluationVarsBase<F>) -> Vec<F> {
|
||||
todo!()
|
||||
let mut constraints = Vec::with_capacity(self.num_constraints());
|
||||
|
||||
// Assert that `swap` is binary.
|
||||
let swap = vars.local_wires[Self::WIRE_SWAP];
|
||||
constraints.push(swap * (swap - F::ONE));
|
||||
|
||||
let mut state = Vec::with_capacity(12);
|
||||
for i in 0..4 {
|
||||
let a = vars.local_wires[i];
|
||||
let b = vars.local_wires[i + 4];
|
||||
state.push(a + swap * (b - a));
|
||||
}
|
||||
for i in 0..4 {
|
||||
let a = vars.local_wires[i + 4];
|
||||
let b = vars.local_wires[i];
|
||||
state.push(a + swap * (b - a));
|
||||
}
|
||||
for i in 8..12 {
|
||||
state.push(vars.local_wires[i]);
|
||||
}
|
||||
|
||||
let mut state: [F; WIDTH] = state.try_into().unwrap();
|
||||
let mut round_ctr = 0;
|
||||
|
||||
for r in 0..poseidon::HALF_N_FULL_ROUNDS {
|
||||
<F as Poseidon<WIDTH>>::constant_layer(&mut state, round_ctr);
|
||||
for i in 0..WIDTH {
|
||||
let sbox_in = vars.local_wires[Self::wire_full_sbox_0(r, i)];
|
||||
constraints.push(state[i] - sbox_in);
|
||||
state[i] = sbox_in;
|
||||
}
|
||||
<F as Poseidon<WIDTH>>::sbox_layer(&mut state);
|
||||
state = <F as Poseidon<WIDTH>>::mds_layer_field(&state);
|
||||
round_ctr += 1;
|
||||
}
|
||||
|
||||
<F as Poseidon<WIDTH>>::partial_first_constant_layer(&mut state);
|
||||
state = <F as Poseidon<WIDTH>>::mds_partial_layer_init(&mut state);
|
||||
for r in 0..(poseidon::N_PARTIAL_ROUNDS - 1) {
|
||||
let sbox_in = vars.local_wires[Self::wire_partial_sbox(r)];
|
||||
constraints.push(state[0] - sbox_in);
|
||||
state[0] = <F as Poseidon<WIDTH>>::sbox_monomial(sbox_in);
|
||||
state[0] +=
|
||||
F::from_canonical_u64(<F as Poseidon<WIDTH>>::FAST_PARTIAL_ROUND_CONSTANTS[r]);
|
||||
state = <F as Poseidon<WIDTH>>::mds_partial_layer_fast_field(&state, r);
|
||||
}
|
||||
let sbox_in = vars.local_wires[Self::wire_partial_sbox(poseidon::N_PARTIAL_ROUNDS - 1)];
|
||||
constraints.push(state[0] - sbox_in);
|
||||
state[0] = <F as Poseidon<WIDTH>>::sbox_monomial(sbox_in);
|
||||
state = <F as Poseidon<WIDTH>>::mds_partial_layer_fast_field(
|
||||
&state,
|
||||
poseidon::N_PARTIAL_ROUNDS - 1,
|
||||
);
|
||||
round_ctr += poseidon::N_PARTIAL_ROUNDS;
|
||||
|
||||
for r in 0..poseidon::HALF_N_FULL_ROUNDS {
|
||||
<F as Poseidon<WIDTH>>::constant_layer(&mut state, round_ctr);
|
||||
for i in 0..WIDTH {
|
||||
let sbox_in = vars.local_wires[Self::wire_full_sbox_1(r, i)];
|
||||
constraints.push(state[i] - sbox_in);
|
||||
state[i] = sbox_in;
|
||||
}
|
||||
<F as Poseidon<WIDTH>>::sbox_layer(&mut state);
|
||||
state = <F as Poseidon<WIDTH>>::mds_layer_field(&state);
|
||||
round_ctr += 1;
|
||||
}
|
||||
|
||||
for i in 0..WIDTH {
|
||||
constraints.push(state[i] - vars.local_wires[Self::wire_output(i)]);
|
||||
}
|
||||
|
||||
constraints
|
||||
}
|
||||
|
||||
fn eval_unfiltered_recursively(
|
||||
@ -178,7 +248,89 @@ where
|
||||
builder: &mut CircuitBuilder<F, D>,
|
||||
vars: EvaluationTargets<D>,
|
||||
) -> Vec<ExtensionTarget<D>> {
|
||||
todo!()
|
||||
let one = builder.one_extension();
|
||||
let mut constraints = Vec::with_capacity(self.num_constraints());
|
||||
|
||||
// Assert that `swap` is binary.
|
||||
let swap = vars.local_wires[Self::WIRE_SWAP];
|
||||
constraints.push(builder.mul_sub_extension(swap, swap, swap));
|
||||
|
||||
let mut state = Vec::with_capacity(12);
|
||||
for i in 0..4 {
|
||||
let a = vars.local_wires[i];
|
||||
let b = vars.local_wires[i + 4];
|
||||
let delta = builder.sub_extension(b, a);
|
||||
state.push(builder.mul_add_extension(swap, delta, a));
|
||||
}
|
||||
for i in 0..4 {
|
||||
let a = vars.local_wires[i + 4];
|
||||
let b = vars.local_wires[i];
|
||||
let delta = builder.sub_extension(b, a);
|
||||
state.push(builder.mul_add_extension(swap, delta, a));
|
||||
}
|
||||
for i in 8..12 {
|
||||
state.push(vars.local_wires[i]);
|
||||
}
|
||||
|
||||
let mut state: [ExtensionTarget<D>; WIDTH] = state.try_into().unwrap();
|
||||
let mut round_ctr = 0;
|
||||
|
||||
for r in 0..poseidon::HALF_N_FULL_ROUNDS {
|
||||
<F as Poseidon<WIDTH>>::constant_layer_recursive(builder, &mut state, round_ctr);
|
||||
for i in 0..WIDTH {
|
||||
let sbox_in = vars.local_wires[Self::wire_full_sbox_0(r, i)];
|
||||
constraints.push(builder.sub_extension(state[i], sbox_in));
|
||||
state[i] = sbox_in;
|
||||
}
|
||||
<F as Poseidon<WIDTH>>::sbox_layer_recursive(builder, &mut state);
|
||||
state = <F as Poseidon<WIDTH>>::mds_layer_recursive(builder, &state);
|
||||
round_ctr += 1;
|
||||
}
|
||||
|
||||
<F as Poseidon<WIDTH>>::partial_first_constant_layer_recursive(builder, &mut state);
|
||||
state = <F as Poseidon<WIDTH>>::mds_partial_layer_init_recursive(builder, &mut state);
|
||||
for r in 0..(poseidon::N_PARTIAL_ROUNDS - 1) {
|
||||
let sbox_in = vars.local_wires[Self::wire_partial_sbox(r)];
|
||||
constraints.push(builder.sub_extension(state[0], sbox_in));
|
||||
state[0] = <F as Poseidon<WIDTH>>::sbox_monomial_recursive(builder, sbox_in);
|
||||
state[0] = builder.arithmetic_extension(
|
||||
F::from_canonical_u64(<F as Poseidon<WIDTH>>::FAST_PARTIAL_ROUND_CONSTANTS[r]),
|
||||
F::ONE,
|
||||
one,
|
||||
one,
|
||||
state[0],
|
||||
);
|
||||
state =
|
||||
<F as Poseidon<WIDTH>>::mds_partial_layer_fast_field_recursive(builder, &state, r);
|
||||
}
|
||||
let sbox_in = vars.local_wires[Self::wire_partial_sbox(poseidon::N_PARTIAL_ROUNDS - 1)];
|
||||
constraints.push(builder.sub_extension(state[0], sbox_in));
|
||||
state[0] = <F as Poseidon<WIDTH>>::sbox_monomial_recursive(builder, sbox_in);
|
||||
state = <F as Poseidon<WIDTH>>::mds_partial_layer_fast_field_recursive(
|
||||
builder,
|
||||
&state,
|
||||
poseidon::N_PARTIAL_ROUNDS - 1,
|
||||
);
|
||||
round_ctr += poseidon::N_PARTIAL_ROUNDS;
|
||||
|
||||
for r in 0..poseidon::HALF_N_FULL_ROUNDS {
|
||||
<F as Poseidon<WIDTH>>::constant_layer_recursive(builder, &mut state, round_ctr);
|
||||
for i in 0..WIDTH {
|
||||
let sbox_in = vars.local_wires[Self::wire_full_sbox_1(r, i)];
|
||||
constraints.push(builder.sub_extension(state[i], sbox_in));
|
||||
state[i] = sbox_in;
|
||||
}
|
||||
<F as Poseidon<WIDTH>>::sbox_layer_recursive(builder, &mut state);
|
||||
state = <F as Poseidon<WIDTH>>::mds_layer_recursive(builder, &state);
|
||||
round_ctr += 1;
|
||||
}
|
||||
|
||||
for i in 0..WIDTH {
|
||||
constraints
|
||||
.push(builder.sub_extension(state[i], vars.local_wires[Self::wire_output(i)]));
|
||||
}
|
||||
|
||||
constraints
|
||||
}
|
||||
|
||||
fn generators(
|
||||
|
||||
@ -7,8 +7,10 @@ use std::convert::TryInto;
|
||||
use unroll::unroll_for_loops;
|
||||
|
||||
use crate::field::crandall_field::CrandallField;
|
||||
use crate::field::extension_field::FieldExtension;
|
||||
use crate::field::field_types::{Field, PrimeField};
|
||||
use crate::field::extension_field::target::ExtensionTarget;
|
||||
use crate::field::extension_field::{Extendable, FieldExtension};
|
||||
use crate::field::field_types::{Field, PrimeField, RichField};
|
||||
use crate::plonk::circuit_builder::CircuitBuilder;
|
||||
|
||||
// The number of full rounds and partial rounds is given by the
|
||||
// calc_round_numbers.py script. They happen to be the same for both
|
||||
@ -192,6 +194,40 @@ where
|
||||
res
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
#[unroll_for_loops]
|
||||
fn mds_row_shf_recursive<F: RichField + Extendable<D>, const D: usize>(
|
||||
builder: &mut CircuitBuilder<F, D>,
|
||||
r: usize,
|
||||
v: &[ExtensionTarget<D>; WIDTH],
|
||||
) -> ExtensionTarget<D> {
|
||||
let one = builder.one_extension();
|
||||
debug_assert!(r < WIDTH);
|
||||
// The values of MDS_MATRIX_EXPS are known to be small, so we can
|
||||
// accumulate all the products for each row and reduce just once
|
||||
// at the end (done by the caller).
|
||||
|
||||
// NB: Unrolling this, calculating each term independently, and
|
||||
// summing at the end, didn't improve performance for me.
|
||||
let mut res = builder.zero_extension();
|
||||
|
||||
// This is a hacky way of fully unrolling the loop.
|
||||
assert!(WIDTH <= 12);
|
||||
for i in 0..12 {
|
||||
if i < WIDTH {
|
||||
res = builder.arithmetic_extension(
|
||||
F::from_canonical_u64(1 << Self::MDS_MATRIX_EXPS[i]),
|
||||
F::ONE,
|
||||
one,
|
||||
v[(i + r) % WIDTH],
|
||||
res,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
res
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
#[unroll_for_loops]
|
||||
fn mds_layer(state_: &[Self; WIDTH]) -> [Self; WIDTH] {
|
||||
@ -231,6 +267,25 @@ where
|
||||
result
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
#[unroll_for_loops]
|
||||
fn mds_layer_recursive<F: RichField + Extendable<D>, const D: usize>(
|
||||
builder: &mut CircuitBuilder<F, D>,
|
||||
state: &[ExtensionTarget<D>; WIDTH],
|
||||
) -> [ExtensionTarget<D>; WIDTH] {
|
||||
let mut result = [builder.zero_extension(); WIDTH];
|
||||
|
||||
// This is a hacky way of fully unrolling the loop.
|
||||
assert!(WIDTH <= 12);
|
||||
for r in 0..12 {
|
||||
if r < WIDTH {
|
||||
result[r] = Self::mds_row_shf_recursive(builder, r, state);
|
||||
}
|
||||
}
|
||||
|
||||
result
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
#[unroll_for_loops]
|
||||
fn partial_first_constant_layer<F: FieldExtension<D, BaseField = Self>, const D: usize>(
|
||||
@ -244,6 +299,27 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
#[unroll_for_loops]
|
||||
fn partial_first_constant_layer_recursive<F: RichField + Extendable<D>, const D: usize>(
|
||||
builder: &mut CircuitBuilder<F, D>,
|
||||
state: &mut [ExtensionTarget<D>; WIDTH],
|
||||
) {
|
||||
let one = builder.one_extension();
|
||||
assert!(WIDTH <= 12);
|
||||
for i in 0..12 {
|
||||
if i < WIDTH {
|
||||
state[i] = builder.arithmetic_extension(
|
||||
F::from_canonical_u64(Self::FAST_PARTIAL_FIRST_ROUND_CONSTANT[i]),
|
||||
F::ONE,
|
||||
one,
|
||||
one,
|
||||
state[i],
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
#[unroll_for_loops]
|
||||
fn mds_partial_layer_init<F: FieldExtension<D, BaseField = Self>, const D: usize>(
|
||||
@ -276,6 +352,41 @@ where
|
||||
result
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
#[unroll_for_loops]
|
||||
fn mds_partial_layer_init_recursive<F: RichField + Extendable<D>, const D: usize>(
|
||||
builder: &mut CircuitBuilder<F, D>,
|
||||
state: &[ExtensionTarget<D>; WIDTH],
|
||||
) -> [ExtensionTarget<D>; WIDTH] {
|
||||
let one = builder.one_extension();
|
||||
let mut result = [builder.zero_extension(); WIDTH];
|
||||
|
||||
// Initial matrix has first row/column = [1, 0, ..., 0];
|
||||
|
||||
// c = 0
|
||||
result[0] = state[0];
|
||||
|
||||
assert!(WIDTH <= 12);
|
||||
for c in 1..12 {
|
||||
if c < WIDTH {
|
||||
assert!(WIDTH <= 12);
|
||||
for r in 1..12 {
|
||||
if r < WIDTH {
|
||||
// NB: FAST_PARTIAL_ROUND_INITIAL_MATRIX is stored in
|
||||
// column-major order so that this dot product is cache
|
||||
// friendly.
|
||||
let t = F::from_canonical_u64(
|
||||
Self::FAST_PARTIAL_ROUND_INITIAL_MATRIX[c - 1][r - 1],
|
||||
);
|
||||
result[c] =
|
||||
builder.arithmetic_extension(t, F::ONE, one, state[r], result[c]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
result
|
||||
}
|
||||
|
||||
/// Computes s*A where s is the state row vector and A is the matrix
|
||||
///
|
||||
/// [ M_00 | v ]
|
||||
@ -343,6 +454,46 @@ where
|
||||
result
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
#[unroll_for_loops]
|
||||
fn mds_partial_layer_fast_field_recursive<F: RichField + Extendable<D>, const D: usize>(
|
||||
builder: &mut CircuitBuilder<F, D>,
|
||||
state: &[ExtensionTarget<D>; WIDTH],
|
||||
r: usize,
|
||||
) -> [ExtensionTarget<D>; WIDTH] {
|
||||
let zero = builder.zero_extension();
|
||||
let one = builder.one_extension();
|
||||
|
||||
// Set d = [M_00 | w^] dot [state]
|
||||
let s0 = state[0];
|
||||
let mut d = builder.arithmetic_extension(
|
||||
F::from_canonical_u64(1 << Self::MDS_MATRIX_EXPS[0]),
|
||||
F::ONE,
|
||||
one,
|
||||
s0,
|
||||
zero,
|
||||
);
|
||||
assert!(WIDTH <= 12);
|
||||
for i in 1..12 {
|
||||
if i < WIDTH {
|
||||
let t = F::from_canonical_u64(Self::FAST_PARTIAL_ROUND_W_HATS[r][i - 1]);
|
||||
d = builder.arithmetic_extension(t, F::ONE, one, state[i], d);
|
||||
}
|
||||
}
|
||||
|
||||
// result = [d] concat [state[0] * v + state[shift up by 1]]
|
||||
let mut result = [zero; WIDTH];
|
||||
result[0] = d;
|
||||
assert!(WIDTH <= 12);
|
||||
for i in 1..12 {
|
||||
if i < WIDTH {
|
||||
let t = F::from_canonical_u64(Self::FAST_PARTIAL_ROUND_VS[r][i - 1]);
|
||||
result[i] = builder.arithmetic_extension(t, F::ONE, one, state[0], state[i]);
|
||||
}
|
||||
}
|
||||
result
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
#[unroll_for_loops]
|
||||
fn constant_layer<F: FieldExtension<D, BaseField = Self>, const D: usize>(
|
||||
@ -357,6 +508,28 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
#[unroll_for_loops]
|
||||
fn constant_layer_recursive<F: RichField + Extendable<D>, const D: usize>(
|
||||
builder: &mut CircuitBuilder<F, D>,
|
||||
state: &mut [ExtensionTarget<D>; WIDTH],
|
||||
round_ctr: usize,
|
||||
) {
|
||||
let one = builder.one_extension();
|
||||
assert!(WIDTH <= 12);
|
||||
for i in 0..12 {
|
||||
if i < WIDTH {
|
||||
state[i] = builder.arithmetic_extension(
|
||||
F::from_canonical_u64(ALL_ROUND_CONSTANTS[i + WIDTH * round_ctr]),
|
||||
F::ONE,
|
||||
one,
|
||||
one,
|
||||
state[i],
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn sbox_monomial<F: FieldExtension<D, BaseField = Self>, const D: usize>(x: F) -> F {
|
||||
// x |--> x^7
|
||||
@ -366,6 +539,18 @@ where
|
||||
x3 * x4
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn sbox_monomial_recursive<F: RichField + Extendable<D>, const D: usize>(
|
||||
builder: &mut CircuitBuilder<F, D>,
|
||||
x: ExtensionTarget<D>,
|
||||
) -> ExtensionTarget<D> {
|
||||
// x |--> x^7
|
||||
let x2 = builder.mul_extension(x, x);
|
||||
let x4 = builder.mul_extension(x2, x2);
|
||||
let x3 = builder.mul_extension(x, x2);
|
||||
builder.mul_extension(x3, x4)
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
#[unroll_for_loops]
|
||||
fn sbox_layer<F: FieldExtension<D, BaseField = Self>, const D: usize>(state: &mut [F; WIDTH]) {
|
||||
@ -377,6 +562,20 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
#[unroll_for_loops]
|
||||
fn sbox_layer_recursive<F: RichField + Extendable<D>, const D: usize>(
|
||||
builder: &mut CircuitBuilder<F, D>,
|
||||
state: &mut [ExtensionTarget<D>; WIDTH],
|
||||
) {
|
||||
assert!(WIDTH <= 12);
|
||||
for i in 0..12 {
|
||||
if i < WIDTH {
|
||||
state[i] = Self::sbox_monomial_recursive(builder, state[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn full_rounds(state: &mut [Self; WIDTH], round_ctr: &mut usize) {
|
||||
for _ in 0..HALF_N_FULL_ROUNDS {
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user