Working recursively

This commit is contained in:
wborgeaud 2021-09-16 19:17:37 +02:00
parent c508fe4362
commit 5d7f4de2a6
2 changed files with 356 additions and 5 deletions

View File

@ -131,7 +131,6 @@ where
<F as Poseidon<WIDTH>>::partial_first_constant_layer(&mut state);
state = <F as Poseidon<WIDTH>>::mds_partial_layer_init(&mut state);
// for r in 0..(poseidon::N_PARTIAL_ROUNDS - 1) {
for r in 0..(poseidon::N_PARTIAL_ROUNDS - 1) {
let sbox_in = vars.local_wires[Self::wire_partial_sbox(r)];
constraints.push(state[0] - sbox_in);
@ -170,7 +169,78 @@ where
}
fn eval_unfiltered_base(&self, vars: EvaluationVarsBase<F>) -> Vec<F> {
todo!()
let mut constraints = Vec::with_capacity(self.num_constraints());
// Assert that `swap` is binary.
let swap = vars.local_wires[Self::WIRE_SWAP];
constraints.push(swap * (swap - F::ONE));
let mut state = Vec::with_capacity(12);
for i in 0..4 {
let a = vars.local_wires[i];
let b = vars.local_wires[i + 4];
state.push(a + swap * (b - a));
}
for i in 0..4 {
let a = vars.local_wires[i + 4];
let b = vars.local_wires[i];
state.push(a + swap * (b - a));
}
for i in 8..12 {
state.push(vars.local_wires[i]);
}
let mut state: [F; WIDTH] = state.try_into().unwrap();
let mut round_ctr = 0;
for r in 0..poseidon::HALF_N_FULL_ROUNDS {
<F as Poseidon<WIDTH>>::constant_layer(&mut state, round_ctr);
for i in 0..WIDTH {
let sbox_in = vars.local_wires[Self::wire_full_sbox_0(r, i)];
constraints.push(state[i] - sbox_in);
state[i] = sbox_in;
}
<F as Poseidon<WIDTH>>::sbox_layer(&mut state);
state = <F as Poseidon<WIDTH>>::mds_layer_field(&state);
round_ctr += 1;
}
<F as Poseidon<WIDTH>>::partial_first_constant_layer(&mut state);
state = <F as Poseidon<WIDTH>>::mds_partial_layer_init(&mut state);
for r in 0..(poseidon::N_PARTIAL_ROUNDS - 1) {
let sbox_in = vars.local_wires[Self::wire_partial_sbox(r)];
constraints.push(state[0] - sbox_in);
state[0] = <F as Poseidon<WIDTH>>::sbox_monomial(sbox_in);
state[0] +=
F::from_canonical_u64(<F as Poseidon<WIDTH>>::FAST_PARTIAL_ROUND_CONSTANTS[r]);
state = <F as Poseidon<WIDTH>>::mds_partial_layer_fast_field(&state, r);
}
let sbox_in = vars.local_wires[Self::wire_partial_sbox(poseidon::N_PARTIAL_ROUNDS - 1)];
constraints.push(state[0] - sbox_in);
state[0] = <F as Poseidon<WIDTH>>::sbox_monomial(sbox_in);
state = <F as Poseidon<WIDTH>>::mds_partial_layer_fast_field(
&state,
poseidon::N_PARTIAL_ROUNDS - 1,
);
round_ctr += poseidon::N_PARTIAL_ROUNDS;
for r in 0..poseidon::HALF_N_FULL_ROUNDS {
<F as Poseidon<WIDTH>>::constant_layer(&mut state, round_ctr);
for i in 0..WIDTH {
let sbox_in = vars.local_wires[Self::wire_full_sbox_1(r, i)];
constraints.push(state[i] - sbox_in);
state[i] = sbox_in;
}
<F as Poseidon<WIDTH>>::sbox_layer(&mut state);
state = <F as Poseidon<WIDTH>>::mds_layer_field(&state);
round_ctr += 1;
}
for i in 0..WIDTH {
constraints.push(state[i] - vars.local_wires[Self::wire_output(i)]);
}
constraints
}
fn eval_unfiltered_recursively(
@ -178,7 +248,89 @@ where
builder: &mut CircuitBuilder<F, D>,
vars: EvaluationTargets<D>,
) -> Vec<ExtensionTarget<D>> {
todo!()
let one = builder.one_extension();
let mut constraints = Vec::with_capacity(self.num_constraints());
// Assert that `swap` is binary.
let swap = vars.local_wires[Self::WIRE_SWAP];
constraints.push(builder.mul_sub_extension(swap, swap, swap));
let mut state = Vec::with_capacity(12);
for i in 0..4 {
let a = vars.local_wires[i];
let b = vars.local_wires[i + 4];
let delta = builder.sub_extension(b, a);
state.push(builder.mul_add_extension(swap, delta, a));
}
for i in 0..4 {
let a = vars.local_wires[i + 4];
let b = vars.local_wires[i];
let delta = builder.sub_extension(b, a);
state.push(builder.mul_add_extension(swap, delta, a));
}
for i in 8..12 {
state.push(vars.local_wires[i]);
}
let mut state: [ExtensionTarget<D>; WIDTH] = state.try_into().unwrap();
let mut round_ctr = 0;
for r in 0..poseidon::HALF_N_FULL_ROUNDS {
<F as Poseidon<WIDTH>>::constant_layer_recursive(builder, &mut state, round_ctr);
for i in 0..WIDTH {
let sbox_in = vars.local_wires[Self::wire_full_sbox_0(r, i)];
constraints.push(builder.sub_extension(state[i], sbox_in));
state[i] = sbox_in;
}
<F as Poseidon<WIDTH>>::sbox_layer_recursive(builder, &mut state);
state = <F as Poseidon<WIDTH>>::mds_layer_recursive(builder, &state);
round_ctr += 1;
}
<F as Poseidon<WIDTH>>::partial_first_constant_layer_recursive(builder, &mut state);
state = <F as Poseidon<WIDTH>>::mds_partial_layer_init_recursive(builder, &mut state);
for r in 0..(poseidon::N_PARTIAL_ROUNDS - 1) {
let sbox_in = vars.local_wires[Self::wire_partial_sbox(r)];
constraints.push(builder.sub_extension(state[0], sbox_in));
state[0] = <F as Poseidon<WIDTH>>::sbox_monomial_recursive(builder, sbox_in);
state[0] = builder.arithmetic_extension(
F::from_canonical_u64(<F as Poseidon<WIDTH>>::FAST_PARTIAL_ROUND_CONSTANTS[r]),
F::ONE,
one,
one,
state[0],
);
state =
<F as Poseidon<WIDTH>>::mds_partial_layer_fast_field_recursive(builder, &state, r);
}
let sbox_in = vars.local_wires[Self::wire_partial_sbox(poseidon::N_PARTIAL_ROUNDS - 1)];
constraints.push(builder.sub_extension(state[0], sbox_in));
state[0] = <F as Poseidon<WIDTH>>::sbox_monomial_recursive(builder, sbox_in);
state = <F as Poseidon<WIDTH>>::mds_partial_layer_fast_field_recursive(
builder,
&state,
poseidon::N_PARTIAL_ROUNDS - 1,
);
round_ctr += poseidon::N_PARTIAL_ROUNDS;
for r in 0..poseidon::HALF_N_FULL_ROUNDS {
<F as Poseidon<WIDTH>>::constant_layer_recursive(builder, &mut state, round_ctr);
for i in 0..WIDTH {
let sbox_in = vars.local_wires[Self::wire_full_sbox_1(r, i)];
constraints.push(builder.sub_extension(state[i], sbox_in));
state[i] = sbox_in;
}
<F as Poseidon<WIDTH>>::sbox_layer_recursive(builder, &mut state);
state = <F as Poseidon<WIDTH>>::mds_layer_recursive(builder, &state);
round_ctr += 1;
}
for i in 0..WIDTH {
constraints
.push(builder.sub_extension(state[i], vars.local_wires[Self::wire_output(i)]));
}
constraints
}
fn generators(

View File

@ -7,8 +7,10 @@ use std::convert::TryInto;
use unroll::unroll_for_loops;
use crate::field::crandall_field::CrandallField;
use crate::field::extension_field::FieldExtension;
use crate::field::field_types::{Field, PrimeField};
use crate::field::extension_field::target::ExtensionTarget;
use crate::field::extension_field::{Extendable, FieldExtension};
use crate::field::field_types::{Field, PrimeField, RichField};
use crate::plonk::circuit_builder::CircuitBuilder;
// The number of full rounds and partial rounds is given by the
// calc_round_numbers.py script. They happen to be the same for both
@ -192,6 +194,40 @@ where
res
}
#[inline(always)]
#[unroll_for_loops]
fn mds_row_shf_recursive<F: RichField + Extendable<D>, const D: usize>(
builder: &mut CircuitBuilder<F, D>,
r: usize,
v: &[ExtensionTarget<D>; WIDTH],
) -> ExtensionTarget<D> {
let one = builder.one_extension();
debug_assert!(r < WIDTH);
// The values of MDS_MATRIX_EXPS are known to be small, so we can
// accumulate all the products for each row and reduce just once
// at the end (done by the caller).
// NB: Unrolling this, calculating each term independently, and
// summing at the end, didn't improve performance for me.
let mut res = builder.zero_extension();
// This is a hacky way of fully unrolling the loop.
assert!(WIDTH <= 12);
for i in 0..12 {
if i < WIDTH {
res = builder.arithmetic_extension(
F::from_canonical_u64(1 << Self::MDS_MATRIX_EXPS[i]),
F::ONE,
one,
v[(i + r) % WIDTH],
res,
);
}
}
res
}
#[inline(always)]
#[unroll_for_loops]
fn mds_layer(state_: &[Self; WIDTH]) -> [Self; WIDTH] {
@ -231,6 +267,25 @@ where
result
}
#[inline(always)]
#[unroll_for_loops]
fn mds_layer_recursive<F: RichField + Extendable<D>, const D: usize>(
builder: &mut CircuitBuilder<F, D>,
state: &[ExtensionTarget<D>; WIDTH],
) -> [ExtensionTarget<D>; WIDTH] {
let mut result = [builder.zero_extension(); WIDTH];
// This is a hacky way of fully unrolling the loop.
assert!(WIDTH <= 12);
for r in 0..12 {
if r < WIDTH {
result[r] = Self::mds_row_shf_recursive(builder, r, state);
}
}
result
}
#[inline(always)]
#[unroll_for_loops]
fn partial_first_constant_layer<F: FieldExtension<D, BaseField = Self>, const D: usize>(
@ -244,6 +299,27 @@ where
}
}
#[inline(always)]
#[unroll_for_loops]
fn partial_first_constant_layer_recursive<F: RichField + Extendable<D>, const D: usize>(
builder: &mut CircuitBuilder<F, D>,
state: &mut [ExtensionTarget<D>; WIDTH],
) {
let one = builder.one_extension();
assert!(WIDTH <= 12);
for i in 0..12 {
if i < WIDTH {
state[i] = builder.arithmetic_extension(
F::from_canonical_u64(Self::FAST_PARTIAL_FIRST_ROUND_CONSTANT[i]),
F::ONE,
one,
one,
state[i],
);
}
}
}
#[inline(always)]
#[unroll_for_loops]
fn mds_partial_layer_init<F: FieldExtension<D, BaseField = Self>, const D: usize>(
@ -276,6 +352,41 @@ where
result
}
#[inline(always)]
#[unroll_for_loops]
fn mds_partial_layer_init_recursive<F: RichField + Extendable<D>, const D: usize>(
builder: &mut CircuitBuilder<F, D>,
state: &[ExtensionTarget<D>; WIDTH],
) -> [ExtensionTarget<D>; WIDTH] {
let one = builder.one_extension();
let mut result = [builder.zero_extension(); WIDTH];
// Initial matrix has first row/column = [1, 0, ..., 0];
// c = 0
result[0] = state[0];
assert!(WIDTH <= 12);
for c in 1..12 {
if c < WIDTH {
assert!(WIDTH <= 12);
for r in 1..12 {
if r < WIDTH {
// NB: FAST_PARTIAL_ROUND_INITIAL_MATRIX is stored in
// column-major order so that this dot product is cache
// friendly.
let t = F::from_canonical_u64(
Self::FAST_PARTIAL_ROUND_INITIAL_MATRIX[c - 1][r - 1],
);
result[c] =
builder.arithmetic_extension(t, F::ONE, one, state[r], result[c]);
}
}
}
}
result
}
/// Computes s*A where s is the state row vector and A is the matrix
///
/// [ M_00 | v ]
@ -343,6 +454,46 @@ where
result
}
#[inline(always)]
#[unroll_for_loops]
fn mds_partial_layer_fast_field_recursive<F: RichField + Extendable<D>, const D: usize>(
builder: &mut CircuitBuilder<F, D>,
state: &[ExtensionTarget<D>; WIDTH],
r: usize,
) -> [ExtensionTarget<D>; WIDTH] {
let zero = builder.zero_extension();
let one = builder.one_extension();
// Set d = [M_00 | w^] dot [state]
let s0 = state[0];
let mut d = builder.arithmetic_extension(
F::from_canonical_u64(1 << Self::MDS_MATRIX_EXPS[0]),
F::ONE,
one,
s0,
zero,
);
assert!(WIDTH <= 12);
for i in 1..12 {
if i < WIDTH {
let t = F::from_canonical_u64(Self::FAST_PARTIAL_ROUND_W_HATS[r][i - 1]);
d = builder.arithmetic_extension(t, F::ONE, one, state[i], d);
}
}
// result = [d] concat [state[0] * v + state[shift up by 1]]
let mut result = [zero; WIDTH];
result[0] = d;
assert!(WIDTH <= 12);
for i in 1..12 {
if i < WIDTH {
let t = F::from_canonical_u64(Self::FAST_PARTIAL_ROUND_VS[r][i - 1]);
result[i] = builder.arithmetic_extension(t, F::ONE, one, state[0], state[i]);
}
}
result
}
#[inline(always)]
#[unroll_for_loops]
fn constant_layer<F: FieldExtension<D, BaseField = Self>, const D: usize>(
@ -357,6 +508,28 @@ where
}
}
#[inline(always)]
#[unroll_for_loops]
fn constant_layer_recursive<F: RichField + Extendable<D>, const D: usize>(
builder: &mut CircuitBuilder<F, D>,
state: &mut [ExtensionTarget<D>; WIDTH],
round_ctr: usize,
) {
let one = builder.one_extension();
assert!(WIDTH <= 12);
for i in 0..12 {
if i < WIDTH {
state[i] = builder.arithmetic_extension(
F::from_canonical_u64(ALL_ROUND_CONSTANTS[i + WIDTH * round_ctr]),
F::ONE,
one,
one,
state[i],
);
}
}
}
#[inline(always)]
fn sbox_monomial<F: FieldExtension<D, BaseField = Self>, const D: usize>(x: F) -> F {
// x |--> x^7
@ -366,6 +539,18 @@ where
x3 * x4
}
#[inline(always)]
fn sbox_monomial_recursive<F: RichField + Extendable<D>, const D: usize>(
builder: &mut CircuitBuilder<F, D>,
x: ExtensionTarget<D>,
) -> ExtensionTarget<D> {
// x |--> x^7
let x2 = builder.mul_extension(x, x);
let x4 = builder.mul_extension(x2, x2);
let x3 = builder.mul_extension(x, x2);
builder.mul_extension(x3, x4)
}
#[inline(always)]
#[unroll_for_loops]
fn sbox_layer<F: FieldExtension<D, BaseField = Self>, const D: usize>(state: &mut [F; WIDTH]) {
@ -377,6 +562,20 @@ where
}
}
#[inline(always)]
#[unroll_for_loops]
fn sbox_layer_recursive<F: RichField + Extendable<D>, const D: usize>(
builder: &mut CircuitBuilder<F, D>,
state: &mut [ExtensionTarget<D>; WIDTH],
) {
assert!(WIDTH <= 12);
for i in 0..12 {
if i < WIDTH {
state[i] = Self::sbox_monomial_recursive(builder, state[i]);
}
}
}
#[inline]
fn full_rounds(state: &mut [Self; WIDTH], round_ctr: &mut usize) {
for _ in 0..HALF_N_FULL_ROUNDS {