From efb1365021c3e8d784121b8501b57bd22d72c5d6 Mon Sep 17 00:00:00 2001 From: Jakub Nabaglo Date: Mon, 7 Feb 2022 14:29:31 -0800 Subject: [PATCH 01/15] Split `system_zero::column_layout` into submodules (#475) --- system_zero/src/column_layout.rs | 149 ++++++++++++++-------------- system_zero/src/permutation_unit.rs | 62 ++++++------ 2 files changed, 104 insertions(+), 107 deletions(-) diff --git a/system_zero/src/column_layout.rs b/system_zero/src/column_layout.rs index 7a9e92e5..fa5d627a 100644 --- a/system_zero/src/column_layout.rs +++ b/system_zero/src/column_layout.rs @@ -1,6 +1,3 @@ -use plonky2::hash::hashing::SPONGE_WIDTH; -use plonky2::hash::poseidon; - //// CORE REGISTERS /// A cycle counter. Starts at 0; increments by 1. @@ -21,87 +18,91 @@ pub(crate) const COL_FRAME_PTR: usize = COL_INSTRUCTION_PTR + 1; pub(crate) const COL_STACK_PTR: usize = COL_FRAME_PTR + 1; //// PERMUTATION UNIT +pub(crate) mod permutation { + use plonky2::hash::hashing::SPONGE_WIDTH; + use plonky2::hash::poseidon; -const START_PERMUTATION_UNIT: usize = COL_STACK_PTR + 1; + const START_UNIT: usize = super::COL_STACK_PTR + 1; -const START_PERMUTATION_FULL_FIRST: usize = START_PERMUTATION_UNIT + SPONGE_WIDTH; + const START_FULL_FIRST: usize = START_UNIT + SPONGE_WIDTH; -pub(crate) const fn col_permutation_full_first_mid_sbox(round: usize, i: usize) -> usize { - debug_assert!(round < poseidon::HALF_N_FULL_ROUNDS); - debug_assert!(i < SPONGE_WIDTH); - START_PERMUTATION_FULL_FIRST + 2 * round * SPONGE_WIDTH + i + pub const fn col_full_first_mid_sbox(round: usize, i: usize) -> usize { + debug_assert!(round < poseidon::HALF_N_FULL_ROUNDS); + debug_assert!(i < SPONGE_WIDTH); + START_FULL_FIRST + 2 * round * SPONGE_WIDTH + i + } + + pub const fn col_full_first_after_mds(round: usize, i: usize) -> usize { + debug_assert!(round < poseidon::HALF_N_FULL_ROUNDS); + debug_assert!(i < SPONGE_WIDTH); + START_FULL_FIRST + (2 * round + 1) * SPONGE_WIDTH + i + } + + const START_PARTIAL: usize = + col_full_first_after_mds(poseidon::HALF_N_FULL_ROUNDS - 1, SPONGE_WIDTH - 1) + 1; + + pub const fn col_partial_mid_sbox(round: usize) -> usize { + debug_assert!(round < poseidon::N_PARTIAL_ROUNDS); + START_PARTIAL + 2 * round + } + + pub const fn col_partial_after_sbox(round: usize) -> usize { + debug_assert!(round < poseidon::N_PARTIAL_ROUNDS); + START_PARTIAL + 2 * round + 1 + } + + const START_FULL_SECOND: usize = col_partial_after_sbox(poseidon::N_PARTIAL_ROUNDS - 1) + 1; + + pub const fn col_full_second_mid_sbox(round: usize, i: usize) -> usize { + debug_assert!(round <= poseidon::HALF_N_FULL_ROUNDS); + debug_assert!(i < SPONGE_WIDTH); + START_FULL_SECOND + 2 * round * SPONGE_WIDTH + i + } + + pub const fn col_full_second_after_mds(round: usize, i: usize) -> usize { + debug_assert!(round <= poseidon::HALF_N_FULL_ROUNDS); + debug_assert!(i < SPONGE_WIDTH); + START_FULL_SECOND + (2 * round + 1) * SPONGE_WIDTH + i + } + + pub const fn col_input(i: usize) -> usize { + debug_assert!(i < SPONGE_WIDTH); + START_UNIT + i + } + + pub const fn col_output(i: usize) -> usize { + debug_assert!(i < SPONGE_WIDTH); + col_full_second_after_mds(poseidon::HALF_N_FULL_ROUNDS - 1, i) + } + + pub(super) const END_UNIT: usize = col_output(SPONGE_WIDTH - 1); } -pub(crate) const fn col_permutation_full_first_after_mds(round: usize, i: usize) -> usize { - debug_assert!(round < poseidon::HALF_N_FULL_ROUNDS); - debug_assert!(i < SPONGE_WIDTH); - START_PERMUTATION_FULL_FIRST + (2 * round + 1) * SPONGE_WIDTH + i -} - -const START_PERMUTATION_PARTIAL: usize = - 
col_permutation_full_first_after_mds(poseidon::HALF_N_FULL_ROUNDS - 1, SPONGE_WIDTH - 1) + 1; - -pub(crate) const fn col_permutation_partial_mid_sbox(round: usize) -> usize { - debug_assert!(round < poseidon::N_PARTIAL_ROUNDS); - START_PERMUTATION_PARTIAL + 2 * round -} - -pub(crate) const fn col_permutation_partial_after_sbox(round: usize) -> usize { - debug_assert!(round < poseidon::N_PARTIAL_ROUNDS); - START_PERMUTATION_PARTIAL + 2 * round + 1 -} - -const START_PERMUTATION_FULL_SECOND: usize = - col_permutation_partial_after_sbox(poseidon::N_PARTIAL_ROUNDS - 1) + 1; - -pub(crate) const fn col_permutation_full_second_mid_sbox(round: usize, i: usize) -> usize { - debug_assert!(round <= poseidon::HALF_N_FULL_ROUNDS); - debug_assert!(i < SPONGE_WIDTH); - START_PERMUTATION_FULL_SECOND + 2 * round * SPONGE_WIDTH + i -} - -pub(crate) const fn col_permutation_full_second_after_mds(round: usize, i: usize) -> usize { - debug_assert!(round <= poseidon::HALF_N_FULL_ROUNDS); - debug_assert!(i < SPONGE_WIDTH); - START_PERMUTATION_FULL_SECOND + (2 * round + 1) * SPONGE_WIDTH + i -} - -pub(crate) const fn col_permutation_input(i: usize) -> usize { - debug_assert!(i < SPONGE_WIDTH); - START_PERMUTATION_UNIT + i -} - -pub(crate) const fn col_permutation_output(i: usize) -> usize { - debug_assert!(i < SPONGE_WIDTH); - col_permutation_full_second_after_mds(poseidon::HALF_N_FULL_ROUNDS - 1, i) -} - -const END_PERMUTATION_UNIT: usize = col_permutation_output(SPONGE_WIDTH - 1); - //// MEMORY UNITS //// DECOMPOSITION UNITS +pub(crate) mod decomposition { -const START_DECOMPOSITION_UNITS: usize = END_PERMUTATION_UNIT + 1; + const START_UNITS: usize = super::permutation::END_UNIT + 1; -const NUM_DECOMPOSITION_UNITS: usize = 4; -/// The number of bits associated with a single decomposition unit. -const DECOMPOSITION_UNIT_BITS: usize = 32; -/// One column for the value being decomposed, plus one column per bit. -const DECOMPOSITION_UNIT_COLS: usize = 1 + DECOMPOSITION_UNIT_BITS; + const NUM_UNITS: usize = 4; + /// The number of bits associated with a single decomposition unit. + const UNIT_BITS: usize = 32; + /// One column for the value being decomposed, plus one column per bit. 
+ const UNIT_COLS: usize = 1 + UNIT_BITS; -pub(crate) const fn col_decomposition_input(unit: usize) -> usize { - debug_assert!(unit < NUM_DECOMPOSITION_UNITS); - START_DECOMPOSITION_UNITS + unit * DECOMPOSITION_UNIT_COLS + pub const fn col_input(unit: usize) -> usize { + debug_assert!(unit < NUM_UNITS); + START_UNITS + unit * UNIT_COLS + } + + pub const fn col_bit(unit: usize, bit: usize) -> usize { + debug_assert!(unit < NUM_UNITS); + debug_assert!(bit < UNIT_BITS); + START_UNITS + unit * UNIT_COLS + 1 + bit + } + + pub(super) const END_UNITS: usize = START_UNITS + UNIT_COLS * NUM_UNITS; } -pub(crate) const fn col_decomposition_bit(unit: usize, bit: usize) -> usize { - debug_assert!(unit < NUM_DECOMPOSITION_UNITS); - debug_assert!(bit < DECOMPOSITION_UNIT_BITS); - START_DECOMPOSITION_UNITS + unit * DECOMPOSITION_UNIT_COLS + 1 + bit -} - -const END_DECOMPOSITION_UNITS: usize = - START_DECOMPOSITION_UNITS + DECOMPOSITION_UNIT_COLS * NUM_DECOMPOSITION_UNITS; - -pub(crate) const NUM_COLUMNS: usize = END_DECOMPOSITION_UNITS; +pub(crate) const NUM_COLUMNS: usize = decomposition::END_UNITS; diff --git a/system_zero/src/permutation_unit.rs b/system_zero/src/permutation_unit.rs index 7f12b9ce..e15474e4 100644 --- a/system_zero/src/permutation_unit.rs +++ b/system_zero/src/permutation_unit.rs @@ -8,15 +8,11 @@ use starky::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsume use starky::vars::StarkEvaluationTargets; use starky::vars::StarkEvaluationVars; -use crate::column_layout::{ - col_permutation_full_first_after_mds as col_full_1st_after_mds, - col_permutation_full_first_mid_sbox as col_full_1st_mid_sbox, - col_permutation_full_second_after_mds as col_full_2nd_after_mds, - col_permutation_full_second_mid_sbox as col_full_2nd_mid_sbox, - col_permutation_input as col_input, - col_permutation_partial_after_sbox as col_partial_after_sbox, - col_permutation_partial_mid_sbox as col_partial_mid_sbox, NUM_COLUMNS, +use crate::column_layout::permutation::{ + col_full_first_after_mds, col_full_first_mid_sbox, col_full_second_after_mds, + col_full_second_mid_sbox, col_input, col_partial_after_sbox, col_partial_mid_sbox, }; +use crate::column_layout::NUM_COLUMNS; use crate::public_input_layout::NUM_PUBLIC_INPUTS; use crate::system_zero::SystemZero; @@ -75,14 +71,14 @@ impl, const D: usize> SystemZero { for i in 0..SPONGE_WIDTH { let state_cubed = state[i].cube(); - values[col_full_1st_mid_sbox(r, i)] = state_cubed; + values[col_full_first_mid_sbox(r, i)] = state_cubed; state[i] *= state_cubed.square(); // Form state ** 7. } state = F::mds_layer(&state); for i in 0..SPONGE_WIDTH { - values[col_full_1st_after_mds(r, i)] = state[i]; + values[col_full_first_after_mds(r, i)] = state[i]; } } @@ -102,14 +98,14 @@ impl, const D: usize> SystemZero { for i in 0..SPONGE_WIDTH { let state_cubed = state[i].cube(); - values[col_full_2nd_mid_sbox(r, i)] = state_cubed; + values[col_full_second_mid_sbox(r, i)] = state_cubed; state[i] *= state_cubed.square(); // Form state ** 7. 
} state = F::mds_layer(&state); for i in 0..SPONGE_WIDTH { - values[col_full_2nd_after_mds(r, i)] = state[i]; + values[col_full_second_after_mds(r, i)] = state[i]; } } } @@ -136,8 +132,8 @@ impl, const D: usize> SystemZero { for i in 0..SPONGE_WIDTH { let state_cubed = state[i] * state[i].square(); yield_constr - .constraint_wrapping(state_cubed - local_values[col_full_1st_mid_sbox(r, i)]); - let state_cubed = local_values[col_full_1st_mid_sbox(r, i)]; + .constraint_wrapping(state_cubed - local_values[col_full_first_mid_sbox(r, i)]); + let state_cubed = local_values[col_full_first_mid_sbox(r, i)]; state[i] *= state_cubed.square(); // Form state ** 7. } @@ -145,8 +141,8 @@ impl, const D: usize> SystemZero { for i in 0..SPONGE_WIDTH { yield_constr - .constraint_wrapping(state[i] - local_values[col_full_1st_after_mds(r, i)]); - state[i] = local_values[col_full_1st_after_mds(r, i)]; + .constraint_wrapping(state[i] - local_values[col_full_first_after_mds(r, i)]); + state[i] = local_values[col_full_first_after_mds(r, i)]; } } @@ -168,9 +164,10 @@ impl, const D: usize> SystemZero { for i in 0..SPONGE_WIDTH { let state_cubed = state[i] * state[i].square(); - yield_constr - .constraint_wrapping(state_cubed - local_values[col_full_2nd_mid_sbox(r, i)]); - let state_cubed = local_values[col_full_2nd_mid_sbox(r, i)]; + yield_constr.constraint_wrapping( + state_cubed - local_values[col_full_second_mid_sbox(r, i)], + ); + let state_cubed = local_values[col_full_second_mid_sbox(r, i)]; state[i] *= state_cubed.square(); // Form state ** 7. } @@ -178,8 +175,8 @@ impl, const D: usize> SystemZero { for i in 0..SPONGE_WIDTH { yield_constr - .constraint_wrapping(state[i] - local_values[col_full_2nd_after_mds(r, i)]); - state[i] = local_values[col_full_2nd_after_mds(r, i)]; + .constraint_wrapping(state[i] - local_values[col_full_second_after_mds(r, i)]); + state[i] = local_values[col_full_second_after_mds(r, i)]; } } } @@ -204,9 +201,9 @@ impl, const D: usize> SystemZero { for i in 0..SPONGE_WIDTH { let state_cubed = builder.cube_extension(state[i]); let diff = - builder.sub_extension(state_cubed, local_values[col_full_1st_mid_sbox(r, i)]); + builder.sub_extension(state_cubed, local_values[col_full_first_mid_sbox(r, i)]); yield_constr.constraint_wrapping(builder, diff); - let state_cubed = local_values[col_full_1st_mid_sbox(r, i)]; + let state_cubed = local_values[col_full_first_mid_sbox(r, i)]; state[i] = builder.mul_many_extension(&[state[i], state_cubed, state_cubed]); // Form state ** 7. } @@ -215,9 +212,9 @@ impl, const D: usize> SystemZero { for i in 0..SPONGE_WIDTH { let diff = - builder.sub_extension(state[i], local_values[col_full_1st_after_mds(r, i)]); + builder.sub_extension(state[i], local_values[col_full_first_after_mds(r, i)]); yield_constr.constraint_wrapping(builder, diff); - state[i] = local_values[col_full_1st_after_mds(r, i)]; + state[i] = local_values[col_full_first_after_mds(r, i)]; } } @@ -245,10 +242,10 @@ impl, const D: usize> SystemZero { for i in 0..SPONGE_WIDTH { let state_cubed = builder.cube_extension(state[i]); - let diff = - builder.sub_extension(state_cubed, local_values[col_full_2nd_mid_sbox(r, i)]); + let diff = builder + .sub_extension(state_cubed, local_values[col_full_second_mid_sbox(r, i)]); yield_constr.constraint_wrapping(builder, diff); - let state_cubed = local_values[col_full_2nd_mid_sbox(r, i)]; + let state_cubed = local_values[col_full_second_mid_sbox(r, i)]; state[i] = builder.mul_many_extension(&[state[i], state_cubed, state_cubed]); // Form state ** 7. 
} @@ -257,9 +254,9 @@ impl, const D: usize> SystemZero { for i in 0..SPONGE_WIDTH { let diff = - builder.sub_extension(state[i], local_values[col_full_2nd_after_mds(r, i)]); + builder.sub_extension(state[i], local_values[col_full_second_after_mds(r, i)]); yield_constr.constraint_wrapping(builder, diff); - state[i] = local_values[col_full_2nd_after_mds(r, i)]; + state[i] = local_values[col_full_second_after_mds(r, i)]; } } } @@ -275,9 +272,8 @@ mod tests { use starky::constraint_consumer::ConstraintConsumer; use starky::vars::StarkEvaluationVars; - use crate::column_layout::{ - col_permutation_input as col_input, col_permutation_output as col_output, NUM_COLUMNS, - }; + use crate::column_layout::permutation::{col_input, col_output}; + use crate::column_layout::NUM_COLUMNS; use crate::permutation_unit::SPONGE_WIDTH; use crate::public_input_layout::NUM_PUBLIC_INPUTS; use crate::system_zero::SystemZero; From 8262389edda8a4916fe45bc79437e7cf91626580 Mon Sep 17 00:00:00 2001 From: BGluth Date: Wed, 9 Feb 2022 10:23:07 -0700 Subject: [PATCH 02/15] Added `Debug`, `Clone`, and `Copy` to ecdsa types --- plonky2/src/curve/ecdsa.rs | 4 ++++ plonky2/src/gadgets/ecdsa.rs | 4 ++++ 2 files changed, 8 insertions(+) diff --git a/plonky2/src/curve/ecdsa.rs b/plonky2/src/curve/ecdsa.rs index c84c4c10..3a5d3c7a 100644 --- a/plonky2/src/curve/ecdsa.rs +++ b/plonky2/src/curve/ecdsa.rs @@ -2,12 +2,16 @@ use crate::curve::curve_msm::msm_parallel; use crate::curve::curve_types::{base_to_scalar, AffinePoint, Curve, CurveScalar}; use crate::field::field_types::Field; +#[derive(Copy, Clone, Debug)] pub struct ECDSASignature { pub r: C::ScalarField, pub s: C::ScalarField, } +#[derive(Copy, Clone, Debug)] pub struct ECDSASecretKey(pub C::ScalarField); + +#[derive(Copy, Clone, Debug)] pub struct ECDSAPublicKey(pub AffinePoint); pub fn sign_message(msg: C::ScalarField, sk: ECDSASecretKey) -> ECDSASignature { diff --git a/plonky2/src/gadgets/ecdsa.rs b/plonky2/src/gadgets/ecdsa.rs index eba04d85..0a95e189 100644 --- a/plonky2/src/gadgets/ecdsa.rs +++ b/plonky2/src/gadgets/ecdsa.rs @@ -7,9 +7,13 @@ use crate::gadgets::nonnative::NonNativeTarget; use crate::hash::hash_types::RichField; use crate::plonk::circuit_builder::CircuitBuilder; +#[derive(Clone, Debug)] pub struct ECDSASecretKeyTarget(NonNativeTarget); + +#[derive(Clone, Debug)] pub struct ECDSAPublicKeyTarget(AffinePointTarget); +#[derive(Clone, Debug)] pub struct ECDSASignatureTarget { pub r: NonNativeTarget, pub s: NonNativeTarget, From adf5444f3fbd7e5011c3b0cd7f5f9a0efdbdb764 Mon Sep 17 00:00:00 2001 From: BGluth Date: Wed, 9 Feb 2022 18:31:58 -0700 Subject: [PATCH 03/15] `from_partial` (non-target) now takes in a slice - Doesn't need to take in a `Vec`. 
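For reference, a minimal caller-side sketch of the new slice-based signature
(the import paths and values below are illustrative, not part of this patch):

    use plonky2::field::field_types::Field;
    use plonky2::field::goldilocks_field::GoldilocksField as F;
    use plonky2::hash::hash_types::HashOut;

    fn main() {
        // Callers can now pass a borrowed slice instead of building a Vec.
        let elems = [F::from_canonical_u64(1), F::from_canonical_u64(2)];
        let h: HashOut<F> = HashOut::from_partial(&elems);
        // The remaining positions are padded with F::ZERO.
        assert_eq!(h.elements[2], F::ZERO);
    }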
--- plonky2/src/hash/hash_types.rs | 14 ++++++-------- plonky2/src/hash/hashing.rs | 4 ++-- 2 files changed, 8 insertions(+), 10 deletions(-) diff --git a/plonky2/src/hash/hash_types.rs b/plonky2/src/hash/hash_types.rs index ed6fca43..8187979b 100644 --- a/plonky2/src/hash/hash_types.rs +++ b/plonky2/src/hash/hash_types.rs @@ -31,14 +31,12 @@ impl HashOut { } } - pub fn from_partial(mut elements: Vec) -> Self { - debug_assert!(elements.len() <= 4); - while elements.len() < 4 { - elements.push(F::ZERO); - } - Self { - elements: [elements[0], elements[1], elements[2], elements[3]], - } + pub fn from_partial(elements_in: &[F]) -> Self { + debug_assert!(elements_in.len() <= 4); + + let mut elements = [F::ZERO; 4]; + elements[0..elements_in.len()].copy_from_slice(elements_in); + Self { elements } } pub fn rand_from_rng(rng: &mut R) -> Self { diff --git a/plonky2/src/hash/hashing.rs b/plonky2/src/hash/hashing.rs index ea205654..eb238e51 100644 --- a/plonky2/src/hash/hashing.rs +++ b/plonky2/src/hash/hashing.rs @@ -14,11 +14,11 @@ pub const SPONGE_WIDTH: usize = SPONGE_RATE + SPONGE_CAPACITY; /// Hash the vector if necessary to reduce its length to ~256 bits. If it already fits, this is a /// no-op. -pub fn hash_or_noop>(inputs: Vec) -> HashOut { +pub fn hash_or_noop>(inputs: &[F]) -> HashOut { if inputs.len() <= 4 { HashOut::from_partial(inputs) } else { - hash_n_to_hash_no_pad::(&inputs) + hash_n_to_hash_no_pad::(inputs) } } From cfe52ad6040a3584b3744e486257129a7c01baff Mon Sep 17 00:00:00 2001 From: Daniel Lubarov Date: Wed, 9 Feb 2022 21:50:18 -0800 Subject: [PATCH 04/15] Add `PrimeField`, `PrimeField64` traits (#457) * Add PrimeField, PrimeField64 traits * fix * fixes * fix * `to_biguint` -> `to_canonical_biguint` --- field/src/extension_field/quadratic.rs | 4 -- field/src/extension_field/quartic.rs | 8 --- field/src/field_types.rs | 20 +++---- field/src/goldilocks_field.rs | 42 ++++++++------- field/src/inversion.rs | 4 +- field/src/prime_field_testing.rs | 6 +-- field/src/secp256k1_base.rs | 36 +++++++------ field/src/secp256k1_scalar.rs | 36 +++++++------ plonky2/src/curve/curve_msm.rs | 6 ++- plonky2/src/curve/curve_multiplication.rs | 3 +- plonky2/src/curve/curve_types.rs | 10 ++-- plonky2/src/curve/secp256k1.rs | 3 +- plonky2/src/gadgets/nonnative.rs | 66 +++++++++++++---------- plonky2/src/gates/assert_le.rs | 3 +- plonky2/src/gates/comparison.rs | 3 +- plonky2/src/gates/subtraction_u32.rs | 3 +- plonky2/src/hash/hash_types.rs | 4 +- plonky2/src/hash/poseidon.rs | 6 +-- plonky2/src/hash/poseidon_goldilocks.rs | 3 +- plonky2/src/iop/generator.rs | 6 +-- plonky2/src/iop/witness.rs | 14 +++-- plonky2/src/util/serialization.rs | 6 +-- waksman/src/sorting.rs | 2 +- 23 files changed, 159 insertions(+), 135 deletions(-) diff --git a/field/src/extension_field/quadratic.rs b/field/src/extension_field/quadratic.rs index e072d323..488304d2 100644 --- a/field/src/extension_field/quadratic.rs +++ b/field/src/extension_field/quadratic.rs @@ -95,10 +95,6 @@ impl> Field for QuadraticExtension { Self([F::from_biguint(low), F::from_biguint(high)]) } - fn to_biguint(&self) -> BigUint { - self.0[0].to_biguint() + F::order() * self.0[1].to_biguint() - } - fn from_canonical_u64(n: u64) -> Self { F::from_canonical_u64(n).into() } diff --git a/field/src/extension_field/quartic.rs b/field/src/extension_field/quartic.rs index 4e9cebf9..7b4a6950 100644 --- a/field/src/extension_field/quartic.rs +++ b/field/src/extension_field/quartic.rs @@ -107,14 +107,6 @@ impl> Field for QuarticExtension { ]) } - fn 
to_biguint(&self) -> BigUint { - let mut result = self.0[3].to_biguint(); - result = result * F::order() + self.0[2].to_biguint(); - result = result * F::order() + self.0[1].to_biguint(); - result = result * F::order() + self.0[0].to_biguint(); - result - } - fn from_canonical_u64(n: u64) -> Self { F::from_canonical_u64(n).into() } diff --git a/field/src/field_types.rs b/field/src/field_types.rs index 65d5bf21..95696475 100644 --- a/field/src/field_types.rs +++ b/field/src/field_types.rs @@ -268,9 +268,6 @@ pub trait Field: // Rename to `from_noncanonical_biguint` and have it return `n % Self::characteristic()`. fn from_biguint(n: BigUint) -> Self; - // TODO: Move to a new `PrimeField` trait. - fn to_biguint(&self) -> BigUint; - /// Returns `n`. Assumes that `n` is already in canonical form, i.e. `n < Self::order()`. // TODO: Should probably be unsafe. fn from_canonical_u64(n: u64) -> Self; @@ -407,16 +404,14 @@ pub trait Field: } } +pub trait PrimeField: Field { + fn to_canonical_biguint(&self) -> BigUint; +} + /// A finite field of order less than 2^64. pub trait Field64: Field { const ORDER: u64; - // TODO: Only well-defined for prime 64-bit fields. Move to a new PrimeField64 trait? - fn to_canonical_u64(&self) -> u64; - - // TODO: Only well-defined for prime 64-bit fields. Move to a new PrimeField64 trait? - fn to_noncanonical_u64(&self) -> u64; - /// Returns `x % Self::CHARACTERISTIC`. // TODO: Move to `Field`. fn from_noncanonical_u64(n: u64) -> Self; @@ -456,6 +451,13 @@ pub trait Field64: Field { } } +/// A finite field of prime order less than 2^64. +pub trait PrimeField64: PrimeField + Field64 { + fn to_canonical_u64(&self) -> u64; + + fn to_noncanonical_u64(&self) -> u64; +} + /// An iterator over the powers of a certain base element `b`: `b^0, b^1, b^2, ...`. #[derive(Clone)] pub struct Powers { diff --git a/field/src/goldilocks_field.rs b/field/src/goldilocks_field.rs index a121b4d2..6c033bb2 100644 --- a/field/src/goldilocks_field.rs +++ b/field/src/goldilocks_field.rs @@ -12,7 +12,7 @@ use serde::{Deserialize, Serialize}; use crate::extension_field::quadratic::QuadraticExtension; use crate::extension_field::quartic::QuarticExtension; use crate::extension_field::{Extendable, Frobenius}; -use crate::field_types::{Field, Field64}; +use crate::field_types::{Field, Field64, PrimeField, PrimeField64}; use crate::inversion::try_inverse_u64; const EPSILON: u64 = (1 << 32) - 1; @@ -98,10 +98,6 @@ impl Field for GoldilocksField { Self(n.mod_floor(&Self::order()).to_u64_digits()[0]) } - fn to_biguint(&self) -> BigUint { - self.to_canonical_u64().into() - } - #[inline] fn from_canonical_u64(n: u64) -> Self { debug_assert!(n < Self::ORDER); @@ -123,23 +119,15 @@ impl Field for GoldilocksField { } } +impl PrimeField for GoldilocksField { + fn to_canonical_biguint(&self) -> BigUint { + self.to_canonical_u64().into() + } +} + impl Field64 for GoldilocksField { const ORDER: u64 = 0xFFFFFFFF00000001; - #[inline] - fn to_canonical_u64(&self) -> u64 { - let mut c = self.0; - // We only need one condition subtraction, since 2 * ORDER would not fit in a u64. - if c >= Self::ORDER { - c -= Self::ORDER; - } - c - } - - fn to_noncanonical_u64(&self) -> u64 { - self.0 - } - #[inline] fn from_noncanonical_u64(n: u64) -> Self { Self(n) @@ -160,6 +148,22 @@ impl Field64 for GoldilocksField { } } +impl PrimeField64 for GoldilocksField { + #[inline] + fn to_canonical_u64(&self) -> u64 { + let mut c = self.0; + // We only need one condition subtraction, since 2 * ORDER would not fit in a u64. 
+ if c >= Self::ORDER { + c -= Self::ORDER; + } + c + } + + fn to_noncanonical_u64(&self) -> u64 { + self.0 + } +} + impl Neg for GoldilocksField { type Output = Self; diff --git a/field/src/inversion.rs b/field/src/inversion.rs index 10c02879..5eabc45c 100644 --- a/field/src/inversion.rs +++ b/field/src/inversion.rs @@ -1,4 +1,4 @@ -use crate::field_types::Field64; +use crate::field_types::PrimeField64; /// This is a 'safe' iteration for the modular inversion algorithm. It /// is safe in the sense that it will produce the right answer even @@ -63,7 +63,7 @@ unsafe fn unsafe_iteration(f: &mut u64, g: &mut u64, c: &mut i128, d: &mut i128, /// Elliptic and Hyperelliptic Cryptography, Algorithms 11.6 /// and 11.12. #[allow(clippy::many_single_char_names)] -pub(crate) fn try_inverse_u64(x: &F) -> Option { +pub(crate) fn try_inverse_u64(x: &F) -> Option { let mut f = x.to_noncanonical_u64(); let mut g = F::ORDER; // NB: These two are very rarely such that their absolute diff --git a/field/src/prime_field_testing.rs b/field/src/prime_field_testing.rs index 772336e9..24d5e3c7 100644 --- a/field/src/prime_field_testing.rs +++ b/field/src/prime_field_testing.rs @@ -1,4 +1,4 @@ -use crate::field_types::Field64; +use crate::field_types::PrimeField64; /// Generates a series of non-negative integers less than `modulus` which cover a range of /// interesting test values. @@ -19,7 +19,7 @@ pub fn test_inputs(modulus: u64) -> Vec { /// word_bits)` and panic if the two resulting vectors differ. pub fn run_unaryop_test_cases(op: UnaryOp, expected_op: ExpectedOp) where - F: Field64, + F: PrimeField64, UnaryOp: Fn(F) -> F, ExpectedOp: Fn(u64) -> u64, { @@ -43,7 +43,7 @@ where /// Apply the binary functions `op` and `expected_op` to each pair of inputs. pub fn run_binaryop_test_cases(op: BinaryOp, expected_op: ExpectedOp) where - F: Field64, + F: PrimeField64, BinaryOp: Fn(F, F) -> F, ExpectedOp: Fn(u64, u64) -> u64, { diff --git a/field/src/secp256k1_base.rs b/field/src/secp256k1_base.rs index 23702420..1972aed7 100644 --- a/field/src/secp256k1_base.rs +++ b/field/src/secp256k1_base.rs @@ -10,7 +10,7 @@ use num::{Integer, One}; use rand::Rng; use serde::{Deserialize, Serialize}; -use crate::field_types::Field; +use crate::field_types::{Field, PrimeField}; /// The base field of the secp256k1 elliptic curve. 
/// @@ -42,7 +42,7 @@ impl Default for Secp256K1Base { impl PartialEq for Secp256K1Base { fn eq(&self, other: &Self) -> bool { - self.to_biguint() == other.to_biguint() + self.to_canonical_biguint() == other.to_canonical_biguint() } } @@ -50,19 +50,19 @@ impl Eq for Secp256K1Base {} impl Hash for Secp256K1Base { fn hash(&self, state: &mut H) { - self.to_biguint().hash(state) + self.to_canonical_biguint().hash(state) } } impl Display for Secp256K1Base { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { - Display::fmt(&self.to_biguint(), f) + Display::fmt(&self.to_canonical_biguint(), f) } } impl Debug for Secp256K1Base { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { - Debug::fmt(&self.to_biguint(), f) + Debug::fmt(&self.to_canonical_biguint(), f) } } @@ -107,14 +107,6 @@ impl Field for Secp256K1Base { Some(self.exp_biguint(&(Self::order() - BigUint::one() - BigUint::one()))) } - fn to_biguint(&self) -> BigUint { - let mut result = biguint_from_array(self.0); - if result >= Self::order() { - result -= Self::order(); - } - result - } - fn from_biguint(val: BigUint) -> Self { Self( val.to_u64_digits() @@ -146,6 +138,16 @@ impl Field for Secp256K1Base { } } +impl PrimeField for Secp256K1Base { + fn to_canonical_biguint(&self) -> BigUint { + let mut result = biguint_from_array(self.0); + if result >= Self::order() { + result -= Self::order(); + } + result + } +} + impl Neg for Secp256K1Base { type Output = Self; @@ -154,7 +156,7 @@ impl Neg for Secp256K1Base { if self.is_zero() { Self::ZERO } else { - Self::from_biguint(Self::order() - self.to_biguint()) + Self::from_biguint(Self::order() - self.to_canonical_biguint()) } } } @@ -164,7 +166,7 @@ impl Add for Secp256K1Base { #[inline] fn add(self, rhs: Self) -> Self { - let mut result = self.to_biguint() + rhs.to_biguint(); + let mut result = self.to_canonical_biguint() + rhs.to_canonical_biguint(); if result >= Self::order() { result -= Self::order(); } @@ -207,7 +209,9 @@ impl Mul for Secp256K1Base { #[inline] fn mul(self, rhs: Self) -> Self { - Self::from_biguint((self.to_biguint() * rhs.to_biguint()).mod_floor(&Self::order())) + Self::from_biguint( + (self.to_canonical_biguint() * rhs.to_canonical_biguint()).mod_floor(&Self::order()), + ) } } diff --git a/field/src/secp256k1_scalar.rs b/field/src/secp256k1_scalar.rs index f10892af..1e506426 100644 --- a/field/src/secp256k1_scalar.rs +++ b/field/src/secp256k1_scalar.rs @@ -11,7 +11,7 @@ use num::{Integer, One}; use rand::Rng; use serde::{Deserialize, Serialize}; -use crate::field_types::Field; +use crate::field_types::{Field, PrimeField}; /// The base field of the secp256k1 elliptic curve. 
/// @@ -45,7 +45,7 @@ impl Default for Secp256K1Scalar { impl PartialEq for Secp256K1Scalar { fn eq(&self, other: &Self) -> bool { - self.to_biguint() == other.to_biguint() + self.to_canonical_biguint() == other.to_canonical_biguint() } } @@ -53,19 +53,19 @@ impl Eq for Secp256K1Scalar {} impl Hash for Secp256K1Scalar { fn hash(&self, state: &mut H) { - self.to_biguint().hash(state) + self.to_canonical_biguint().hash(state) } } impl Display for Secp256K1Scalar { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { - Display::fmt(&self.to_biguint(), f) + Display::fmt(&self.to_canonical_biguint(), f) } } impl Debug for Secp256K1Scalar { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { - Debug::fmt(&self.to_biguint(), f) + Debug::fmt(&self.to_canonical_biguint(), f) } } @@ -116,14 +116,6 @@ impl Field for Secp256K1Scalar { Some(self.exp_biguint(&(Self::order() - BigUint::one() - BigUint::one()))) } - fn to_biguint(&self) -> BigUint { - let mut result = biguint_from_array(self.0); - if result >= Self::order() { - result -= Self::order(); - } - result - } - fn from_biguint(val: BigUint) -> Self { Self( val.to_u64_digits() @@ -155,6 +147,16 @@ impl Field for Secp256K1Scalar { } } +impl PrimeField for Secp256K1Scalar { + fn to_canonical_biguint(&self) -> BigUint { + let mut result = biguint_from_array(self.0); + if result >= Self::order() { + result -= Self::order(); + } + result + } +} + impl Neg for Secp256K1Scalar { type Output = Self; @@ -163,7 +165,7 @@ impl Neg for Secp256K1Scalar { if self.is_zero() { Self::ZERO } else { - Self::from_biguint(Self::order() - self.to_biguint()) + Self::from_biguint(Self::order() - self.to_canonical_biguint()) } } } @@ -173,7 +175,7 @@ impl Add for Secp256K1Scalar { #[inline] fn add(self, rhs: Self) -> Self { - let mut result = self.to_biguint() + rhs.to_biguint(); + let mut result = self.to_canonical_biguint() + rhs.to_canonical_biguint(); if result >= Self::order() { result -= Self::order(); } @@ -216,7 +218,9 @@ impl Mul for Secp256K1Scalar { #[inline] fn mul(self, rhs: Self) -> Self { - Self::from_biguint((self.to_biguint() * rhs.to_biguint()).mod_floor(&Self::order())) + Self::from_biguint( + (self.to_canonical_biguint() * rhs.to_canonical_biguint()).mod_floor(&Self::order()), + ) } } diff --git a/plonky2/src/curve/curve_msm.rs b/plonky2/src/curve/curve_msm.rs index 388c0321..4c274c1c 100644 --- a/plonky2/src/curve/curve_msm.rs +++ b/plonky2/src/curve/curve_msm.rs @@ -1,5 +1,6 @@ use itertools::Itertools; use plonky2_field::field_types::Field; +use plonky2_field::field_types::PrimeField; use rayon::prelude::*; use crate::curve::curve_summation::affine_multisummation_best; @@ -160,7 +161,7 @@ pub(crate) fn to_digits(x: &C::ScalarField, w: usize) -> Vec { // Convert x to a bool array. 
let x_canonical: Vec<_> = x - .to_biguint() + .to_canonical_biguint() .to_u64_digits() .iter() .cloned() @@ -187,6 +188,7 @@ pub(crate) fn to_digits(x: &C::ScalarField, w: usize) -> Vec { mod tests { use num::BigUint; use plonky2_field::field_types::Field; + use plonky2_field::field_types::PrimeField; use plonky2_field::secp256k1_scalar::Secp256K1Scalar; use crate::curve::curve_msm::{msm_execute, msm_precompute, to_digits}; @@ -206,7 +208,7 @@ mod tests { 0b11111111111111111111111111111111, ]; let x = Secp256K1Scalar::from_biguint(BigUint::from_slice(&x_canonical)); - assert_eq!(x.to_biguint().to_u32_digits(), x_canonical); + assert_eq!(x.to_canonical_biguint().to_u32_digits(), x_canonical); assert_eq!( to_digits::(&x, 17), vec![ diff --git a/plonky2/src/curve/curve_multiplication.rs b/plonky2/src/curve/curve_multiplication.rs index 30da4973..c6fbbd83 100644 --- a/plonky2/src/curve/curve_multiplication.rs +++ b/plonky2/src/curve/curve_multiplication.rs @@ -1,6 +1,7 @@ use std::ops::Mul; use plonky2_field::field_types::Field; +use plonky2_field::field_types::PrimeField; use crate::curve::curve_types::{Curve, CurveScalar, ProjectivePoint}; @@ -88,7 +89,7 @@ fn to_digits(x: &C::ScalarField) -> Vec { ); let digits_per_u64 = 64 / WINDOW_BITS; let mut digits = Vec::with_capacity(digits_per_scalar::()); - for limb in x.to_biguint().to_u64_digits() { + for limb in x.to_canonical_biguint().to_u64_digits() { for j in 0..digits_per_u64 { digits.push((limb >> (j * WINDOW_BITS) as u64) % BASE as u64); } diff --git a/plonky2/src/curve/curve_types.rs b/plonky2/src/curve/curve_types.rs index 9599f6fe..0a9e8711 100644 --- a/plonky2/src/curve/curve_types.rs +++ b/plonky2/src/curve/curve_types.rs @@ -1,7 +1,7 @@ use std::fmt::Debug; use std::ops::Neg; -use plonky2_field::field_types::Field; +use plonky2_field::field_types::{Field, PrimeField}; use plonky2_field::ops::Square; // To avoid implementation conflicts from associated types, @@ -10,8 +10,8 @@ pub struct CurveScalar(pub ::ScalarField); /// A short Weierstrass curve. 
pub trait Curve: 'static + Sync + Sized + Copy + Debug { - type BaseField: Field; - type ScalarField: Field; + type BaseField: PrimeField; + type ScalarField: PrimeField; const A: Self::BaseField; const B: Self::BaseField; @@ -261,9 +261,9 @@ impl Neg for ProjectivePoint { } pub fn base_to_scalar(x: C::BaseField) -> C::ScalarField { - C::ScalarField::from_biguint(x.to_biguint()) + C::ScalarField::from_biguint(x.to_canonical_biguint()) } pub fn scalar_to_base(x: C::ScalarField) -> C::BaseField { - C::BaseField::from_biguint(x.to_biguint()) + C::BaseField::from_biguint(x.to_canonical_biguint()) } diff --git a/plonky2/src/curve/secp256k1.rs b/plonky2/src/curve/secp256k1.rs index d9039719..6a460735 100644 --- a/plonky2/src/curve/secp256k1.rs +++ b/plonky2/src/curve/secp256k1.rs @@ -40,6 +40,7 @@ const SECP256K1_GENERATOR_Y: Secp256K1Base = Secp256K1Base([ mod tests { use num::BigUint; use plonky2_field::field_types::Field; + use plonky2_field::field_types::PrimeField; use plonky2_field::secp256k1_scalar::Secp256K1Scalar; use crate::curve::curve_types::{AffinePoint, Curve, ProjectivePoint}; @@ -86,7 +87,7 @@ mod tests { ) -> ProjectivePoint { let mut g = rhs; let mut sum = ProjectivePoint::ZERO; - for limb in lhs.to_biguint().to_u64_digits().iter() { + for limb in lhs.to_canonical_biguint().to_u64_digits().iter() { for j in 0..64 { if (limb >> j & 1u64) != 0u64 { sum = sum + g; diff --git a/plonky2/src/gadgets/nonnative.rs b/plonky2/src/gadgets/nonnative.rs index 245b0403..3f8d29e8 100644 --- a/plonky2/src/gadgets/nonnative.rs +++ b/plonky2/src/gadgets/nonnative.rs @@ -1,6 +1,7 @@ use std::marker::PhantomData; use num::{BigUint, Integer, One, Zero}; +use plonky2_field::field_types::PrimeField; use plonky2_field::{extension_field::Extendable, field_types::Field}; use plonky2_util::ceil_div_usize; @@ -34,12 +35,12 @@ impl, const D: usize> CircuitBuilder { x.value.clone() } - pub fn constant_nonnative(&mut self, x: FF) -> NonNativeTarget { - let x_biguint = self.constant_biguint(&x.to_biguint()); + pub fn constant_nonnative(&mut self, x: FF) -> NonNativeTarget { + let x_biguint = self.constant_biguint(&x.to_canonical_biguint()); self.biguint_to_nonnative(&x_biguint) } - pub fn zero_nonnative(&mut self) -> NonNativeTarget { + pub fn zero_nonnative(&mut self) -> NonNativeTarget { self.constant_nonnative(FF::ZERO) } @@ -62,7 +63,7 @@ impl, const D: usize> CircuitBuilder { } } - pub fn add_nonnative( + pub fn add_nonnative( &mut self, a: &NonNativeTarget, b: &NonNativeTarget, @@ -105,7 +106,7 @@ impl, const D: usize> CircuitBuilder { } } - pub fn add_many_nonnative( + pub fn add_many_nonnative( &mut self, to_add: &[NonNativeTarget], ) -> NonNativeTarget { @@ -149,7 +150,7 @@ impl, const D: usize> CircuitBuilder { } // Subtract two `NonNativeTarget`s. 
- pub fn sub_nonnative( + pub fn sub_nonnative( &mut self, a: &NonNativeTarget, b: &NonNativeTarget, @@ -177,7 +178,7 @@ impl, const D: usize> CircuitBuilder { diff } - pub fn mul_nonnative( + pub fn mul_nonnative( &mut self, a: &NonNativeTarget, b: &NonNativeTarget, @@ -208,7 +209,7 @@ impl, const D: usize> CircuitBuilder { prod } - pub fn mul_many_nonnative( + pub fn mul_many_nonnative( &mut self, to_mul: &[NonNativeTarget], ) -> NonNativeTarget { @@ -223,14 +224,20 @@ impl, const D: usize> CircuitBuilder { accumulator } - pub fn neg_nonnative(&mut self, x: &NonNativeTarget) -> NonNativeTarget { + pub fn neg_nonnative( + &mut self, + x: &NonNativeTarget, + ) -> NonNativeTarget { let zero_target = self.constant_biguint(&BigUint::zero()); let zero_ff = self.biguint_to_nonnative(&zero_target); self.sub_nonnative(&zero_ff, x) } - pub fn inv_nonnative(&mut self, x: &NonNativeTarget) -> NonNativeTarget { + pub fn inv_nonnative( + &mut self, + x: &NonNativeTarget, + ) -> NonNativeTarget { let num_limbs = x.value.num_limbs(); let inv_biguint = self.add_virtual_biguint_target(num_limbs); let div = self.add_virtual_biguint_target(num_limbs); @@ -307,7 +314,7 @@ impl, const D: usize> CircuitBuilder { } #[derive(Debug)] -struct NonNativeAdditionGenerator, const D: usize, FF: Field> { +struct NonNativeAdditionGenerator, const D: usize, FF: PrimeField> { a: NonNativeTarget, b: NonNativeTarget, sum: NonNativeTarget, @@ -315,7 +322,7 @@ struct NonNativeAdditionGenerator, const D: usize, _phantom: PhantomData, } -impl, const D: usize, FF: Field> SimpleGenerator +impl, const D: usize, FF: PrimeField> SimpleGenerator for NonNativeAdditionGenerator { fn dependencies(&self) -> Vec { @@ -332,8 +339,8 @@ impl, const D: usize, FF: Field> SimpleGenerator fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { let a = witness.get_nonnative_target(self.a.clone()); let b = witness.get_nonnative_target(self.b.clone()); - let a_biguint = a.to_biguint(); - let b_biguint = b.to_biguint(); + let a_biguint = a.to_canonical_biguint(); + let b_biguint = b.to_canonical_biguint(); let sum_biguint = a_biguint + b_biguint; let modulus = FF::order(); let (overflow, sum_reduced) = if sum_biguint > modulus { @@ -348,14 +355,15 @@ impl, const D: usize, FF: Field> SimpleGenerator } #[derive(Debug)] -struct NonNativeMultipleAddsGenerator, const D: usize, FF: Field> { +struct NonNativeMultipleAddsGenerator, const D: usize, FF: PrimeField> +{ summands: Vec>, sum: NonNativeTarget, overflow: U32Target, _phantom: PhantomData, } -impl, const D: usize, FF: Field> SimpleGenerator +impl, const D: usize, FF: PrimeField> SimpleGenerator for NonNativeMultipleAddsGenerator { fn dependencies(&self) -> Vec { @@ -373,7 +381,7 @@ impl, const D: usize, FF: Field> SimpleGenerator .collect(); let summand_biguints: Vec<_> = summands .iter() - .map(|summand| summand.to_biguint()) + .map(|summand| summand.to_canonical_biguint()) .collect(); let sum_biguint = summand_biguints @@ -398,7 +406,7 @@ struct NonNativeSubtractionGenerator, const D: usiz _phantom: PhantomData, } -impl, const D: usize, FF: Field> SimpleGenerator +impl, const D: usize, FF: PrimeField> SimpleGenerator for NonNativeSubtractionGenerator { fn dependencies(&self) -> Vec { @@ -415,8 +423,8 @@ impl, const D: usize, FF: Field> SimpleGenerator fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { let a = witness.get_nonnative_target(self.a.clone()); let b = witness.get_nonnative_target(self.b.clone()); - let a_biguint = a.to_biguint(); 
- let b_biguint = b.to_biguint(); + let a_biguint = a.to_canonical_biguint(); + let b_biguint = b.to_canonical_biguint(); let modulus = FF::order(); let (diff_biguint, overflow) = if a_biguint > b_biguint { @@ -439,7 +447,7 @@ struct NonNativeMultiplicationGenerator, const D: u _phantom: PhantomData, } -impl, const D: usize, FF: Field> SimpleGenerator +impl, const D: usize, FF: PrimeField> SimpleGenerator for NonNativeMultiplicationGenerator { fn dependencies(&self) -> Vec { @@ -456,8 +464,8 @@ impl, const D: usize, FF: Field> SimpleGenerator fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { let a = witness.get_nonnative_target(self.a.clone()); let b = witness.get_nonnative_target(self.b.clone()); - let a_biguint = a.to_biguint(); - let b_biguint = b.to_biguint(); + let a_biguint = a.to_canonical_biguint(); + let b_biguint = b.to_canonical_biguint(); let prod_biguint = a_biguint * b_biguint; @@ -470,14 +478,14 @@ impl, const D: usize, FF: Field> SimpleGenerator } #[derive(Debug)] -struct NonNativeInverseGenerator, const D: usize, FF: Field> { +struct NonNativeInverseGenerator, const D: usize, FF: PrimeField> { x: NonNativeTarget, inv: BigUintTarget, div: BigUintTarget, _phantom: PhantomData, } -impl, const D: usize, FF: Field> SimpleGenerator +impl, const D: usize, FF: PrimeField> SimpleGenerator for NonNativeInverseGenerator { fn dependencies(&self) -> Vec { @@ -488,8 +496,8 @@ impl, const D: usize, FF: Field> SimpleGenerator let x = witness.get_nonnative_target(self.x.clone()); let inv = x.inverse(); - let x_biguint = x.to_biguint(); - let inv_biguint = inv.to_biguint(); + let x_biguint = x.to_canonical_biguint(); + let inv_biguint = inv.to_canonical_biguint(); let prod = x_biguint * &inv_biguint; let modulus = FF::order(); let (div, _rem) = prod.div_rem(&modulus); @@ -502,7 +510,7 @@ impl, const D: usize, FF: Field> SimpleGenerator #[cfg(test)] mod tests { use anyhow::Result; - use plonky2_field::field_types::Field; + use plonky2_field::field_types::{Field, PrimeField}; use plonky2_field::secp256k1_base::Secp256K1Base; use crate::iop::witness::PartialWitness; @@ -587,7 +595,7 @@ mod tests { let x_ff = FF::rand(); let mut y_ff = FF::rand(); - while y_ff.to_biguint() > x_ff.to_biguint() { + while y_ff.to_canonical_biguint() > x_ff.to_canonical_biguint() { y_ff = FF::rand(); } let diff_ff = x_ff - y_ff; diff --git a/plonky2/src/gates/assert_le.rs b/plonky2/src/gates/assert_le.rs index c087a963..cec7274b 100644 --- a/plonky2/src/gates/assert_le.rs +++ b/plonky2/src/gates/assert_le.rs @@ -455,7 +455,8 @@ mod tests { use anyhow::Result; use plonky2_field::extension_field::quartic::QuarticExtension; - use plonky2_field::field_types::{Field, Field64}; + use plonky2_field::field_types::Field; + use plonky2_field::field_types::PrimeField64; use plonky2_field::goldilocks_field::GoldilocksField; use rand::Rng; diff --git a/plonky2/src/gates/comparison.rs b/plonky2/src/gates/comparison.rs index bc3e69b9..b1cf7b98 100644 --- a/plonky2/src/gates/comparison.rs +++ b/plonky2/src/gates/comparison.rs @@ -520,7 +520,8 @@ mod tests { use std::marker::PhantomData; use anyhow::Result; - use plonky2_field::field_types::{Field, Field64}; + use plonky2_field::field_types::Field; + use plonky2_field::field_types::PrimeField64; use plonky2_field::goldilocks_field::GoldilocksField; use rand::Rng; diff --git a/plonky2/src/gates/subtraction_u32.rs b/plonky2/src/gates/subtraction_u32.rs index 80bc03ed..b1e4d84f 100644 --- a/plonky2/src/gates/subtraction_u32.rs +++ 
b/plonky2/src/gates/subtraction_u32.rs @@ -338,7 +338,8 @@ mod tests { use anyhow::Result; use plonky2_field::extension_field::quartic::QuarticExtension; - use plonky2_field::field_types::{Field, Field64}; + use plonky2_field::field_types::Field; + use plonky2_field::field_types::PrimeField64; use plonky2_field::goldilocks_field::GoldilocksField; use rand::Rng; diff --git a/plonky2/src/hash/hash_types.rs b/plonky2/src/hash/hash_types.rs index ed6fca43..0a1cedd0 100644 --- a/plonky2/src/hash/hash_types.rs +++ b/plonky2/src/hash/hash_types.rs @@ -1,4 +1,4 @@ -use plonky2_field::field_types::{Field, Field64}; +use plonky2_field::field_types::{Field, PrimeField64}; use plonky2_field::goldilocks_field::GoldilocksField; use rand::Rng; use serde::{Deserialize, Deserializer, Serialize, Serializer}; @@ -8,7 +8,7 @@ use crate::iop::target::Target; use crate::plonk::config::GenericHashOut; /// A prime order field with the features we need to use it as a base field in our argument system. -pub trait RichField: Field64 + Poseidon {} +pub trait RichField: PrimeField64 + Poseidon {} impl RichField for GoldilocksField {} diff --git a/plonky2/src/hash/poseidon.rs b/plonky2/src/hash/poseidon.rs index 08c2851a..09c5d2fc 100644 --- a/plonky2/src/hash/poseidon.rs +++ b/plonky2/src/hash/poseidon.rs @@ -2,7 +2,7 @@ //! https://eprint.iacr.org/2019/458.pdf use plonky2_field::extension_field::{Extendable, FieldExtension}; -use plonky2_field::field_types::{Field, Field64}; +use plonky2_field::field_types::{Field, PrimeField64}; use unroll::unroll_for_loops; use crate::gates::gate::Gate; @@ -35,7 +35,7 @@ fn add_u160_u128((x_lo, x_hi): (u128, u32), y: u128) -> (u128, u32) { } #[inline(always)] -fn reduce_u160((n_lo, n_hi): (u128, u32)) -> F { +fn reduce_u160((n_lo, n_hi): (u128, u32)) -> F { let n_lo_hi = (n_lo >> 64) as u64; let n_lo_lo = n_lo as u64; let reduced_hi: u64 = F::from_noncanonical_u96((n_lo_hi, n_hi)).to_noncanonical_u64(); @@ -148,7 +148,7 @@ pub const ALL_ROUND_CONSTANTS: [u64; MAX_WIDTH * N_ROUNDS] = [ ]; const WIDTH: usize = SPONGE_WIDTH; -pub trait Poseidon: Field64 { +pub trait Poseidon: PrimeField64 { // Total number of round constants required: width of the input // times number of rounds. 
const N_ROUND_CONSTANTS: usize = WIDTH * N_ROUNDS; diff --git a/plonky2/src/hash/poseidon_goldilocks.rs b/plonky2/src/hash/poseidon_goldilocks.rs index ab886847..7b82bb01 100644 --- a/plonky2/src/hash/poseidon_goldilocks.rs +++ b/plonky2/src/hash/poseidon_goldilocks.rs @@ -270,7 +270,8 @@ impl Poseidon for GoldilocksField { #[cfg(test)] mod tests { - use plonky2_field::field_types::{Field, Field64}; + use plonky2_field::field_types::Field; + use plonky2_field::field_types::PrimeField64; use plonky2_field::goldilocks_field::GoldilocksField as F; use crate::hash::poseidon::test_helpers::{check_consistency, check_test_vectors}; diff --git a/plonky2/src/iop/generator.rs b/plonky2/src/iop/generator.rs index 73978f5c..1569e889 100644 --- a/plonky2/src/iop/generator.rs +++ b/plonky2/src/iop/generator.rs @@ -3,7 +3,7 @@ use std::marker::PhantomData; use num::BigUint; use plonky2_field::extension_field::{Extendable, FieldExtension}; -use plonky2_field::field_types::Field; +use plonky2_field::field_types::{Field, PrimeField}; use crate::gadgets::arithmetic_u32::U32Target; use crate::gadgets::biguint::BigUintTarget; @@ -180,8 +180,8 @@ impl GeneratedValues { } } - pub fn set_nonnative_target(&mut self, target: NonNativeTarget, value: FF) { - self.set_biguint_target(target.value, value.to_biguint()) + pub fn set_nonnative_target(&mut self, target: NonNativeTarget, value: FF) { + self.set_biguint_target(target.value, value.to_canonical_biguint()) } pub fn set_hash_target(&mut self, ht: HashOutTarget, value: HashOut) { diff --git a/plonky2/src/iop/witness.rs b/plonky2/src/iop/witness.rs index 43dc752d..e1bdf06e 100644 --- a/plonky2/src/iop/witness.rs +++ b/plonky2/src/iop/witness.rs @@ -3,7 +3,7 @@ use std::collections::HashMap; use itertools::Itertools; use num::{BigUint, FromPrimitive, Zero}; use plonky2_field::extension_field::{Extendable, FieldExtension}; -use plonky2_field::field_types::Field; +use plonky2_field::field_types::{Field, PrimeField}; use crate::fri::witness_util::set_fri_proof_target; use crate::gadgets::arithmetic_u32::U32Target; @@ -62,20 +62,26 @@ pub trait Witness { panic!("not a bool") } - fn get_biguint_target(&self, target: BigUintTarget) -> BigUint { + fn get_biguint_target(&self, target: BigUintTarget) -> BigUint + where + F: PrimeField, + { let mut result = BigUint::zero(); let limb_base = BigUint::from_u64(1 << 32u64).unwrap(); for i in (0..target.num_limbs()).rev() { let limb = target.get_limb(i); result *= &limb_base; - result += self.get_target(limb.0).to_biguint(); + result += self.get_target(limb.0).to_canonical_biguint(); } result } - fn get_nonnative_target(&self, target: NonNativeTarget) -> FF { + fn get_nonnative_target(&self, target: NonNativeTarget) -> FF + where + F: PrimeField, + { let val = self.get_biguint_target(target.value); FF::from_biguint(val) } diff --git a/plonky2/src/util/serialization.rs b/plonky2/src/util/serialization.rs index adc8baee..d0326073 100644 --- a/plonky2/src/util/serialization.rs +++ b/plonky2/src/util/serialization.rs @@ -3,7 +3,7 @@ use std::io::Cursor; use std::io::{Read, Result, Write}; use plonky2_field::extension_field::{Extendable, FieldExtension}; -use plonky2_field::field_types::Field64; +use plonky2_field::field_types::{Field64, PrimeField64}; use plonky2_field::polynomial::PolynomialCoeffs; use crate::fri::proof::{ @@ -53,7 +53,7 @@ impl Buffer { Ok(u32::from_le_bytes(buf)) } - fn write_field(&mut self, x: F) -> Result<()> { + fn write_field(&mut self, x: F) -> Result<()> { 
self.0.write_all(&x.to_canonical_u64().to_le_bytes()) } fn read_field(&mut self) -> Result { @@ -116,7 +116,7 @@ impl Buffer { )) } - pub fn write_field_vec(&mut self, v: &[F]) -> Result<()> { + pub fn write_field_vec(&mut self, v: &[F]) -> Result<()> { for &a in v { self.write_field(a)?; } diff --git a/waksman/src/sorting.rs b/waksman/src/sorting.rs index b154436e..286205b1 100644 --- a/waksman/src/sorting.rs +++ b/waksman/src/sorting.rs @@ -183,7 +183,7 @@ impl, const D: usize> SimpleGenerator #[cfg(test)] mod tests { use anyhow::Result; - use plonky2::field::field_types::{Field, Field64}; + use plonky2::field::field_types::{Field, PrimeField64}; use plonky2::iop::witness::PartialWitness; use plonky2::plonk::circuit_data::CircuitConfig; use plonky2::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; From b2c747b171df954757d64a0fad4482d235e22a1d Mon Sep 17 00:00:00 2001 From: BGluth Date: Wed, 9 Feb 2022 23:34:26 -0700 Subject: [PATCH 05/15] Also did the same to the circuit version - And removed the `debug_assert!`. --- plonky2/src/hash/hash_types.rs | 14 ++++---------- plonky2/src/hash/hashing.rs | 4 ++-- 2 files changed, 6 insertions(+), 12 deletions(-) diff --git a/plonky2/src/hash/hash_types.rs b/plonky2/src/hash/hash_types.rs index 8187979b..f7605306 100644 --- a/plonky2/src/hash/hash_types.rs +++ b/plonky2/src/hash/hash_types.rs @@ -32,8 +32,6 @@ impl HashOut { } pub fn from_partial(elements_in: &[F]) -> Self { - debug_assert!(elements_in.len() <= 4); - let mut elements = [F::ZERO; 4]; elements[0..elements_in.len()].copy_from_slice(elements_in); Self { elements } @@ -102,14 +100,10 @@ impl HashOutTarget { } } - pub fn from_partial(mut elements: Vec, zero: Target) -> Self { - debug_assert!(elements.len() <= 4); - while elements.len() < 4 { - elements.push(zero); - } - Self { - elements: [elements[0], elements[1], elements[2], elements[3]], - } + pub fn from_partial(elements_in: &[Target], zero: Target) -> Self { + let mut elements = [zero; 4]; + elements[0..elements_in.len()].copy_from_slice(elements_in); + Self { elements } } } diff --git a/plonky2/src/hash/hashing.rs b/plonky2/src/hash/hashing.rs index eb238e51..468bd1b8 100644 --- a/plonky2/src/hash/hashing.rs +++ b/plonky2/src/hash/hashing.rs @@ -12,7 +12,7 @@ pub(crate) const SPONGE_RATE: usize = 8; pub(crate) const SPONGE_CAPACITY: usize = 4; pub const SPONGE_WIDTH: usize = SPONGE_RATE + SPONGE_CAPACITY; -/// Hash the vector if necessary to reduce its length to ~256 bits. If it already fits, this is a +/// Hash the slice if necessary to reduce its length to ~256 bits. If it already fits, this is a /// no-op. pub fn hash_or_noop>(inputs: &[F]) -> HashOut { if inputs.len() <= 4 { @@ -26,7 +26,7 @@ impl, const D: usize> CircuitBuilder { pub fn hash_or_noop>(&mut self, inputs: Vec) -> HashOutTarget { let zero = self.zero(); if inputs.len() <= 4 { - HashOutTarget::from_partial(inputs, zero) + HashOutTarget::from_partial(&inputs, zero) } else { self.hash_n_to_hash_no_pad::(inputs) } From 645d45f227a2c1537529a544f625ede6ca964bc2 Mon Sep 17 00:00:00 2001 From: Daniel Lubarov Date: Thu, 10 Feb 2022 12:05:04 -0800 Subject: [PATCH 06/15] Column definitions for addition, range checks & lookups (#477) * Column definitions for addition, range checks & lookups I implemented addition (unsigned for now) as an example of how the arithmetic unit can interact with the 16-bit range check unit. Range checks and lookups aren't implemented yet. 
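As a plain-Rust sketch of the limb handling (not part of the patch; the helper
names are made up): the trace generator splits the sum into three 16-bit limbs,
each of which is what the 16-bit range check unit would constrain, and the
addition constraint recombines them.

    // Split the (up to ~48-bit) sum into the three 16-bit output limbs
    // written to the trace; each limb goes to the 16-bit range check unit.
    fn output_limbs(sum: u64) -> [u16; 3] {
        [sum as u16, (sum >> 16) as u16, (sum >> 32) as u16]
    }

    // The addition constraint recombines the limbs with weights 2^16 and 2^32:
    //   out_1 + out_2 * 2^16 + out_3 * 2^32 == in_1 + in_2 + in_3
    // Since each limb is range checked to 16 bits, the left-hand side stays
    // below 2^48 and cannot wrap the Goldilocks field.
    fn recombine(limbs: [u16; 3]) -> u64 {
        limbs[0] as u64 + ((limbs[1] as u64) << 16) + ((limbs[2] as u64) << 32)
    }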
* Missing constraints * Tweaks to get tests passing * Reorg registers into files * Minor --- field/src/field_types.rs | 6 + plonky2/src/plonk/plonk_common.rs | 2 +- starky/src/lib.rs | 2 - system_zero/src/arithmetic/addition.rs | 70 ++++++++++++ system_zero/src/arithmetic/division.rs | 31 +++++ system_zero/src/arithmetic/mod.rs | 75 ++++++++++++ system_zero/src/arithmetic/multiplication.rs | 31 +++++ system_zero/src/arithmetic/subtraction.rs | 31 +++++ system_zero/src/column_layout.rs | 108 ------------------ system_zero/src/core_registers.rs | 39 +++++-- system_zero/src/lib.rs | 5 +- system_zero/src/permutation_unit.rs | 11 +- system_zero/src/registers/arithmetic.rs | 37 ++++++ system_zero/src/registers/boolean.rs | 10 ++ system_zero/src/registers/core.rs | 20 ++++ system_zero/src/registers/logic.rs | 3 + system_zero/src/registers/lookup.rs | 21 ++++ system_zero/src/registers/memory.rs | 3 + system_zero/src/registers/mod.rs | 20 ++++ system_zero/src/registers/permutation.rs | 57 +++++++++ system_zero/src/registers/range_check_16.rs | 11 ++ .../src/registers/range_check_degree.rs | 11 ++ system_zero/src/system_zero.rs | 20 +++- 23 files changed, 489 insertions(+), 135 deletions(-) create mode 100644 system_zero/src/arithmetic/addition.rs create mode 100644 system_zero/src/arithmetic/division.rs create mode 100644 system_zero/src/arithmetic/mod.rs create mode 100644 system_zero/src/arithmetic/multiplication.rs create mode 100644 system_zero/src/arithmetic/subtraction.rs delete mode 100644 system_zero/src/column_layout.rs create mode 100644 system_zero/src/registers/arithmetic.rs create mode 100644 system_zero/src/registers/boolean.rs create mode 100644 system_zero/src/registers/core.rs create mode 100644 system_zero/src/registers/logic.rs create mode 100644 system_zero/src/registers/lookup.rs create mode 100644 system_zero/src/registers/memory.rs create mode 100644 system_zero/src/registers/mod.rs create mode 100644 system_zero/src/registers/permutation.rs create mode 100644 system_zero/src/registers/range_check_16.rs create mode 100644 system_zero/src/registers/range_check_degree.rs diff --git a/field/src/field_types.rs b/field/src/field_types.rs index 95696475..83826b9f 100644 --- a/field/src/field_types.rs +++ b/field/src/field_types.rs @@ -278,6 +278,12 @@ pub trait Field: Self::from_canonical_u64(n as u64) } + /// Returns `n`. Assumes that `n` is already in canonical form, i.e. `n < Self::order()`. + // TODO: Should probably be unsafe. + fn from_canonical_u16(n: u16) -> Self { + Self::from_canonical_u64(n as u64) + } + /// Returns `n`. Assumes that `n` is already in canonical form, i.e. `n < Self::order()`. // TODO: Should probably be unsafe. fn from_canonical_usize(n: usize) -> Self { diff --git a/plonky2/src/plonk/plonk_common.rs b/plonky2/src/plonk/plonk_common.rs index 519593b3..09cf2652 100644 --- a/plonky2/src/plonk/plonk_common.rs +++ b/plonky2/src/plonk/plonk_common.rs @@ -138,7 +138,7 @@ where sum } -pub(crate) fn reduce_with_powers_ext_recursive, const D: usize>( +pub fn reduce_with_powers_ext_recursive, const D: usize>( builder: &mut CircuitBuilder, terms: &[ExtensionTarget], alpha: Target, diff --git a/starky/src/lib.rs b/starky/src/lib.rs index dc61e7e7..eefab529 100644 --- a/starky/src/lib.rs +++ b/starky/src/lib.rs @@ -1,8 +1,6 @@ // TODO: Remove these when crate is closer to being finished. 
#![allow(dead_code)] #![allow(unused_variables)] -#![allow(unreachable_code)] -#![allow(clippy::diverging_sub_expression)] #![allow(incomplete_features)] #![feature(generic_const_exprs)] diff --git a/system_zero/src/arithmetic/addition.rs b/system_zero/src/arithmetic/addition.rs new file mode 100644 index 00000000..653d533b --- /dev/null +++ b/system_zero/src/arithmetic/addition.rs @@ -0,0 +1,70 @@ +use plonky2::field::extension_field::Extendable; +use plonky2::field::field_types::Field; +use plonky2::field::packed_field::PackedField; +use plonky2::hash::hash_types::RichField; +use plonky2::iop::ext_target::ExtensionTarget; +use plonky2::plonk::circuit_builder::CircuitBuilder; +use plonky2::plonk::plonk_common::reduce_with_powers_ext_recursive; +use starky::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; + +use crate::registers::arithmetic::*; +use crate::registers::NUM_COLUMNS; + +pub(crate) fn generate_addition(values: &mut [F; NUM_COLUMNS]) { + let in_1 = values[COL_ADD_INPUT_1].to_canonical_u64(); + let in_2 = values[COL_ADD_INPUT_2].to_canonical_u64(); + let in_3 = values[COL_ADD_INPUT_3].to_canonical_u64(); + let output = in_1 + in_2 + in_3; + + values[COL_ADD_OUTPUT_1] = F::from_canonical_u16(output as u16); + values[COL_ADD_OUTPUT_2] = F::from_canonical_u16((output >> 16) as u16); + values[COL_ADD_OUTPUT_3] = F::from_canonical_u16((output >> 32) as u16); +} + +pub(crate) fn eval_addition>( + local_values: &[P; NUM_COLUMNS], + yield_constr: &mut ConstraintConsumer
<P>
, +) { + let is_add = local_values[IS_ADD]; + let in_1 = local_values[COL_ADD_INPUT_1]; + let in_2 = local_values[COL_ADD_INPUT_2]; + let in_3 = local_values[COL_ADD_INPUT_3]; + let out_1 = local_values[COL_ADD_OUTPUT_1]; + let out_2 = local_values[COL_ADD_OUTPUT_2]; + let out_3 = local_values[COL_ADD_OUTPUT_3]; + + let weight_2 = F::from_canonical_u64(1 << 16); + let weight_3 = F::from_canonical_u64(1 << 32); + // Note that this can't overflow. Since each output limb has been range checked as 16-bits, + // this sum can be around 48 bits at most. + let out = out_1 + out_2 * weight_2 + out_3 * weight_3; + + let computed_out = in_1 + in_2 + in_3; + + yield_constr.constraint_wrapping(is_add * (out - computed_out)); +} + +pub(crate) fn eval_addition_recursively, const D: usize>( + builder: &mut CircuitBuilder, + local_values: &[ExtensionTarget; NUM_COLUMNS], + yield_constr: &mut RecursiveConstraintConsumer, +) { + let is_add = local_values[IS_ADD]; + let in_1 = local_values[COL_ADD_INPUT_1]; + let in_2 = local_values[COL_ADD_INPUT_2]; + let in_3 = local_values[COL_ADD_INPUT_3]; + let out_1 = local_values[COL_ADD_OUTPUT_1]; + let out_2 = local_values[COL_ADD_OUTPUT_2]; + let out_3 = local_values[COL_ADD_OUTPUT_3]; + + let limb_base = builder.constant(F::from_canonical_u64(1 << 16)); + // Note that this can't overflow. Since each output limb has been range checked as 16-bits, + // this sum can be around 48 bits at most. + let out = reduce_with_powers_ext_recursive(builder, &[out_1, out_2, out_3], limb_base); + + let computed_out = builder.add_many_extension(&[in_1, in_2, in_3]); + + let diff = builder.sub_extension(out, computed_out); + let filtered_diff = builder.mul_extension(is_add, diff); + yield_constr.constraint_wrapping(builder, filtered_diff); +} diff --git a/system_zero/src/arithmetic/division.rs b/system_zero/src/arithmetic/division.rs new file mode 100644 index 00000000..2f15b233 --- /dev/null +++ b/system_zero/src/arithmetic/division.rs @@ -0,0 +1,31 @@ +use plonky2::field::extension_field::Extendable; +use plonky2::field::field_types::Field; +use plonky2::field::packed_field::PackedField; +use plonky2::hash::hash_types::RichField; +use plonky2::iop::ext_target::ExtensionTarget; +use plonky2::plonk::circuit_builder::CircuitBuilder; +use starky::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; + +use crate::registers::arithmetic::*; +use crate::registers::NUM_COLUMNS; + +pub(crate) fn generate_division(values: &mut [F; NUM_COLUMNS]) { + // TODO +} + +pub(crate) fn eval_division>( + local_values: &[P; NUM_COLUMNS], + yield_constr: &mut ConstraintConsumer
<P>
, +) { + let is_div = local_values[IS_DIV]; + // TODO +} + +pub(crate) fn eval_division_recursively, const D: usize>( + builder: &mut CircuitBuilder, + local_values: &[ExtensionTarget; NUM_COLUMNS], + yield_constr: &mut RecursiveConstraintConsumer, +) { + let is_div = local_values[IS_DIV]; + // TODO +} diff --git a/system_zero/src/arithmetic/mod.rs b/system_zero/src/arithmetic/mod.rs new file mode 100644 index 00000000..c635d58d --- /dev/null +++ b/system_zero/src/arithmetic/mod.rs @@ -0,0 +1,75 @@ +use plonky2::field::extension_field::Extendable; +use plonky2::field::field_types::Field; +use plonky2::field::packed_field::PackedField; +use plonky2::hash::hash_types::RichField; +use plonky2::plonk::circuit_builder::CircuitBuilder; +use starky::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; +use starky::vars::StarkEvaluationTargets; +use starky::vars::StarkEvaluationVars; + +use crate::arithmetic::addition::{eval_addition, eval_addition_recursively, generate_addition}; +use crate::arithmetic::division::{eval_division, eval_division_recursively, generate_division}; +use crate::arithmetic::multiplication::{ + eval_multiplication, eval_multiplication_recursively, generate_multiplication, +}; +use crate::arithmetic::subtraction::{ + eval_subtraction, eval_subtraction_recursively, generate_subtraction, +}; +use crate::public_input_layout::NUM_PUBLIC_INPUTS; +use crate::registers::arithmetic::*; +use crate::registers::NUM_COLUMNS; + +mod addition; +mod division; +mod multiplication; +mod subtraction; + +pub(crate) fn generate_arithmetic_unit(values: &mut [F; NUM_COLUMNS]) { + if values[IS_ADD].is_one() { + generate_addition(values); + } else if values[IS_SUB].is_one() { + generate_subtraction(values); + } else if values[IS_MUL].is_one() { + generate_multiplication(values); + } else if values[IS_DIV].is_one() { + generate_division(values); + } +} + +pub(crate) fn eval_arithmetic_unit>( + vars: StarkEvaluationVars, + yield_constr: &mut ConstraintConsumer
<P>
, +) { + let local_values = &vars.local_values; + + // Check that the operation flag values are binary. + for col in [IS_ADD, IS_SUB, IS_MUL, IS_DIV] { + let val = local_values[col]; + yield_constr.constraint_wrapping(val * val - val); + } + + eval_addition(local_values, yield_constr); + eval_subtraction(local_values, yield_constr); + eval_multiplication(local_values, yield_constr); + eval_division(local_values, yield_constr); +} + +pub(crate) fn eval_arithmetic_unit_recursively, const D: usize>( + builder: &mut CircuitBuilder, + vars: StarkEvaluationTargets, + yield_constr: &mut RecursiveConstraintConsumer, +) { + let local_values = &vars.local_values; + + // Check that the operation flag values are binary. + for col in [IS_ADD, IS_SUB, IS_MUL, IS_DIV] { + let val = local_values[col]; + let constraint = builder.mul_add_extension(val, val, val); + yield_constr.constraint_wrapping(builder, constraint); + } + + eval_addition_recursively(builder, local_values, yield_constr); + eval_subtraction_recursively(builder, local_values, yield_constr); + eval_multiplication_recursively(builder, local_values, yield_constr); + eval_division_recursively(builder, local_values, yield_constr); +} diff --git a/system_zero/src/arithmetic/multiplication.rs b/system_zero/src/arithmetic/multiplication.rs new file mode 100644 index 00000000..2eefad38 --- /dev/null +++ b/system_zero/src/arithmetic/multiplication.rs @@ -0,0 +1,31 @@ +use plonky2::field::extension_field::Extendable; +use plonky2::field::field_types::Field; +use plonky2::field::packed_field::PackedField; +use plonky2::hash::hash_types::RichField; +use plonky2::iop::ext_target::ExtensionTarget; +use plonky2::plonk::circuit_builder::CircuitBuilder; +use starky::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; + +use crate::registers::arithmetic::*; +use crate::registers::NUM_COLUMNS; + +pub(crate) fn generate_multiplication(values: &mut [F; NUM_COLUMNS]) { + // TODO +} + +pub(crate) fn eval_multiplication>( + local_values: &[P; NUM_COLUMNS], + yield_constr: &mut ConstraintConsumer
<P>
, +) { + let is_mul = local_values[IS_MUL]; + // TODO +} + +pub(crate) fn eval_multiplication_recursively, const D: usize>( + builder: &mut CircuitBuilder, + local_values: &[ExtensionTarget; NUM_COLUMNS], + yield_constr: &mut RecursiveConstraintConsumer, +) { + let is_mul = local_values[IS_MUL]; + // TODO +} diff --git a/system_zero/src/arithmetic/subtraction.rs b/system_zero/src/arithmetic/subtraction.rs new file mode 100644 index 00000000..3613dee6 --- /dev/null +++ b/system_zero/src/arithmetic/subtraction.rs @@ -0,0 +1,31 @@ +use plonky2::field::extension_field::Extendable; +use plonky2::field::field_types::Field; +use plonky2::field::packed_field::PackedField; +use plonky2::hash::hash_types::RichField; +use plonky2::iop::ext_target::ExtensionTarget; +use plonky2::plonk::circuit_builder::CircuitBuilder; +use starky::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; + +use crate::registers::arithmetic::*; +use crate::registers::NUM_COLUMNS; + +pub(crate) fn generate_subtraction(values: &mut [F; NUM_COLUMNS]) { + // TODO +} + +pub(crate) fn eval_subtraction>( + local_values: &[P; NUM_COLUMNS], + yield_constr: &mut ConstraintConsumer
<P>
, +) { + let is_sub = local_values[IS_SUB]; + // TODO +} + +pub(crate) fn eval_subtraction_recursively, const D: usize>( + builder: &mut CircuitBuilder, + local_values: &[ExtensionTarget; NUM_COLUMNS], + yield_constr: &mut RecursiveConstraintConsumer, +) { + let is_sub = local_values[IS_SUB]; + // TODO +} diff --git a/system_zero/src/column_layout.rs b/system_zero/src/column_layout.rs deleted file mode 100644 index fa5d627a..00000000 --- a/system_zero/src/column_layout.rs +++ /dev/null @@ -1,108 +0,0 @@ -//// CORE REGISTERS - -/// A cycle counter. Starts at 0; increments by 1. -pub(crate) const COL_CLOCK: usize = 0; - -/// A column which contains the values `[0, ... 2^16 - 1]`, potentially with duplicates. Used for -/// 16-bit range checks. -/// -/// For ease of verification, we enforce that it must begin with 0 and end with `2^16 - 1`, and each -/// delta must be either 0 or 1. -pub(crate) const COL_RANGE_16: usize = COL_CLOCK + 1; - -/// Pointer to the current instruction. -pub(crate) const COL_INSTRUCTION_PTR: usize = COL_RANGE_16 + 1; -/// Pointer to the base of the current call's stack frame. -pub(crate) const COL_FRAME_PTR: usize = COL_INSTRUCTION_PTR + 1; -/// Pointer to the tip of the current call's stack frame. -pub(crate) const COL_STACK_PTR: usize = COL_FRAME_PTR + 1; - -//// PERMUTATION UNIT -pub(crate) mod permutation { - use plonky2::hash::hashing::SPONGE_WIDTH; - use plonky2::hash::poseidon; - - const START_UNIT: usize = super::COL_STACK_PTR + 1; - - const START_FULL_FIRST: usize = START_UNIT + SPONGE_WIDTH; - - pub const fn col_full_first_mid_sbox(round: usize, i: usize) -> usize { - debug_assert!(round < poseidon::HALF_N_FULL_ROUNDS); - debug_assert!(i < SPONGE_WIDTH); - START_FULL_FIRST + 2 * round * SPONGE_WIDTH + i - } - - pub const fn col_full_first_after_mds(round: usize, i: usize) -> usize { - debug_assert!(round < poseidon::HALF_N_FULL_ROUNDS); - debug_assert!(i < SPONGE_WIDTH); - START_FULL_FIRST + (2 * round + 1) * SPONGE_WIDTH + i - } - - const START_PARTIAL: usize = - col_full_first_after_mds(poseidon::HALF_N_FULL_ROUNDS - 1, SPONGE_WIDTH - 1) + 1; - - pub const fn col_partial_mid_sbox(round: usize) -> usize { - debug_assert!(round < poseidon::N_PARTIAL_ROUNDS); - START_PARTIAL + 2 * round - } - - pub const fn col_partial_after_sbox(round: usize) -> usize { - debug_assert!(round < poseidon::N_PARTIAL_ROUNDS); - START_PARTIAL + 2 * round + 1 - } - - const START_FULL_SECOND: usize = col_partial_after_sbox(poseidon::N_PARTIAL_ROUNDS - 1) + 1; - - pub const fn col_full_second_mid_sbox(round: usize, i: usize) -> usize { - debug_assert!(round <= poseidon::HALF_N_FULL_ROUNDS); - debug_assert!(i < SPONGE_WIDTH); - START_FULL_SECOND + 2 * round * SPONGE_WIDTH + i - } - - pub const fn col_full_second_after_mds(round: usize, i: usize) -> usize { - debug_assert!(round <= poseidon::HALF_N_FULL_ROUNDS); - debug_assert!(i < SPONGE_WIDTH); - START_FULL_SECOND + (2 * round + 1) * SPONGE_WIDTH + i - } - - pub const fn col_input(i: usize) -> usize { - debug_assert!(i < SPONGE_WIDTH); - START_UNIT + i - } - - pub const fn col_output(i: usize) -> usize { - debug_assert!(i < SPONGE_WIDTH); - col_full_second_after_mds(poseidon::HALF_N_FULL_ROUNDS - 1, i) - } - - pub(super) const END_UNIT: usize = col_output(SPONGE_WIDTH - 1); -} - -//// MEMORY UNITS - -//// DECOMPOSITION UNITS -pub(crate) mod decomposition { - - const START_UNITS: usize = super::permutation::END_UNIT + 1; - - const NUM_UNITS: usize = 4; - /// The number of bits associated with a single decomposition unit. 
- const UNIT_BITS: usize = 32; - /// One column for the value being decomposed, plus one column per bit. - const UNIT_COLS: usize = 1 + UNIT_BITS; - - pub const fn col_input(unit: usize) -> usize { - debug_assert!(unit < NUM_UNITS); - START_UNITS + unit * UNIT_COLS - } - - pub const fn col_bit(unit: usize, bit: usize) -> usize { - debug_assert!(unit < NUM_UNITS); - debug_assert!(bit < UNIT_BITS); - START_UNITS + unit * UNIT_COLS + 1 + bit - } - - pub(super) const END_UNITS: usize = START_UNITS + UNIT_COLS * NUM_UNITS; -} - -pub(crate) const NUM_COLUMNS: usize = decomposition::END_UNITS; diff --git a/system_zero/src/core_registers.rs b/system_zero/src/core_registers.rs index 21faa288..03e7fa04 100644 --- a/system_zero/src/core_registers.rs +++ b/system_zero/src/core_registers.rs @@ -6,10 +6,9 @@ use starky::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsume use starky::vars::StarkEvaluationTargets; use starky::vars::StarkEvaluationVars; -use crate::column_layout::{ - COL_CLOCK, COL_FRAME_PTR, COL_INSTRUCTION_PTR, COL_RANGE_16, COL_STACK_PTR, NUM_COLUMNS, -}; use crate::public_input_layout::NUM_PUBLIC_INPUTS; +use crate::registers::core::*; +use crate::registers::NUM_COLUMNS; use crate::system_zero::SystemZero; impl, const D: usize> SystemZero { @@ -35,11 +34,11 @@ impl, const D: usize> SystemZero { let next_range_16 = (prev_range_16 + 1).min((1 << 16) - 1); next_values[COL_RANGE_16] = F::from_canonical_u64(next_range_16); - next_values[COL_INSTRUCTION_PTR] = todo!(); + // next_values[COL_INSTRUCTION_PTR] = todo!(); - next_values[COL_FRAME_PTR] = todo!(); + // next_values[COL_FRAME_PTR] = todo!(); - next_values[COL_STACK_PTR] = todo!(); + // next_values[COL_STACK_PTR] = todo!(); } #[inline] @@ -64,9 +63,9 @@ impl, const D: usize> SystemZero { let delta_range_16 = next_range_16 - local_range_16; yield_constr.constraint_first_row(local_range_16); yield_constr.constraint_last_row(local_range_16 - FE::from_canonical_u64((1 << 16) - 1)); - yield_constr.constraint(delta_range_16 * (delta_range_16 - FE::ONE)); + yield_constr.constraint(delta_range_16 * delta_range_16 - delta_range_16); - todo!() + // TODO constraints for stack etc. } pub(crate) fn eval_core_registers_recursively( @@ -75,6 +74,28 @@ impl, const D: usize> SystemZero { vars: StarkEvaluationTargets, yield_constr: &mut RecursiveConstraintConsumer, ) { - todo!() + let one_ext = builder.one_extension(); + let max_u16 = builder.constant(F::from_canonical_u64((1 << 16) - 1)); + let max_u16_ext = builder.convert_to_ext(max_u16); + + // The clock must start with 0, and increment by 1. + let local_clock = vars.local_values[COL_CLOCK]; + let next_clock = vars.next_values[COL_CLOCK]; + let delta_clock = builder.sub_extension(next_clock, local_clock); + yield_constr.constraint_first_row(builder, local_clock); + let constraint = builder.sub_extension(delta_clock, one_ext); + yield_constr.constraint(builder, constraint); + + // The 16-bit table must start with 0, end with 2^16 - 1, and increment by 0 or 1. 
+ let local_range_16 = vars.local_values[COL_RANGE_16]; + let next_range_16 = vars.next_values[COL_RANGE_16]; + let delta_range_16 = builder.sub_extension(next_range_16, local_range_16); + yield_constr.constraint_first_row(builder, local_range_16); + let constraint = builder.sub_extension(local_range_16, max_u16_ext); + yield_constr.constraint_last_row(builder, constraint); + let constraint = builder.mul_add_extension(delta_range_16, delta_range_16, delta_range_16); + yield_constr.constraint(builder, constraint); + + // TODO constraints for stack etc. } } diff --git a/system_zero/src/lib.rs b/system_zero/src/lib.rs index 029c2abd..1c097573 100644 --- a/system_zero/src/lib.rs +++ b/system_zero/src/lib.rs @@ -1,12 +1,11 @@ // TODO: Remove these when crate is closer to being finished. #![allow(dead_code)] #![allow(unused_variables)] -#![allow(unreachable_code)] -#![allow(clippy::diverging_sub_expression)] -mod column_layout; +mod arithmetic; mod core_registers; mod memory; mod permutation_unit; mod public_input_layout; +mod registers; pub mod system_zero; diff --git a/system_zero/src/permutation_unit.rs b/system_zero/src/permutation_unit.rs index e15474e4..2681f2d9 100644 --- a/system_zero/src/permutation_unit.rs +++ b/system_zero/src/permutation_unit.rs @@ -8,12 +8,9 @@ use starky::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsume use starky::vars::StarkEvaluationTargets; use starky::vars::StarkEvaluationVars; -use crate::column_layout::permutation::{ - col_full_first_after_mds, col_full_first_mid_sbox, col_full_second_after_mds, - col_full_second_mid_sbox, col_input, col_partial_after_sbox, col_partial_mid_sbox, -}; -use crate::column_layout::NUM_COLUMNS; use crate::public_input_layout::NUM_PUBLIC_INPUTS; +use crate::registers::permutation::*; +use crate::registers::NUM_COLUMNS; use crate::system_zero::SystemZero; fn constant_layer( @@ -272,10 +269,10 @@ mod tests { use starky::constraint_consumer::ConstraintConsumer; use starky::vars::StarkEvaluationVars; - use crate::column_layout::permutation::{col_input, col_output}; - use crate::column_layout::NUM_COLUMNS; use crate::permutation_unit::SPONGE_WIDTH; use crate::public_input_layout::NUM_PUBLIC_INPUTS; + use crate::registers::permutation::{col_input, col_output}; + use crate::registers::NUM_COLUMNS; use crate::system_zero::SystemZero; #[test] diff --git a/system_zero/src/registers/arithmetic.rs b/system_zero/src/registers/arithmetic.rs new file mode 100644 index 00000000..92c0d2c3 --- /dev/null +++ b/system_zero/src/registers/arithmetic.rs @@ -0,0 +1,37 @@ +//! Arithmetic unit. + +pub(crate) const IS_ADD: usize = super::START_ARITHMETIC; +pub(crate) const IS_SUB: usize = IS_ADD + 1; +pub(crate) const IS_MUL: usize = IS_SUB + 1; +pub(crate) const IS_DIV: usize = IS_MUL + 1; + +const START_SHARED_COLS: usize = IS_DIV + 1; + +/// Within the arithmetic unit, there are shared columns which can be used by any arithmetic +/// circuit, depending on which one is active this cycle. +// Can be increased as needed as other operations are implemented. +const NUM_SHARED_COLS: usize = 3; + +const fn shared_col(i: usize) -> usize { + debug_assert!(i < NUM_SHARED_COLS); + START_SHARED_COLS + i +} + +/// The first value to be added; treated as an unsigned u32. +pub(crate) const COL_ADD_INPUT_1: usize = shared_col(0); +/// The second value to be added; treated as an unsigned u32. +pub(crate) const COL_ADD_INPUT_2: usize = shared_col(1); +/// The third value to be added; treated as an unsigned u32. 
+pub(crate) const COL_ADD_INPUT_3: usize = shared_col(2); + +// Note: Addition outputs three 16-bit chunks, and since these values need to be range-checked +// anyway, we might as well use the range check unit's columns as our addition outputs. So the +// three proceeding columns are basically aliases, not columns owned by the arithmetic unit. +/// The first 16-bit chunk of the output, based on little-endian ordering. +pub(crate) const COL_ADD_OUTPUT_1: usize = super::range_check_16::col_rc_16_input(0); +/// The second 16-bit chunk of the output, based on little-endian ordering. +pub(crate) const COL_ADD_OUTPUT_2: usize = super::range_check_16::col_rc_16_input(1); +/// The third 16-bit chunk of the output, based on little-endian ordering. +pub(crate) const COL_ADD_OUTPUT_3: usize = super::range_check_16::col_rc_16_input(2); + +pub(super) const END: usize = super::START_ARITHMETIC + NUM_SHARED_COLS; diff --git a/system_zero/src/registers/boolean.rs b/system_zero/src/registers/boolean.rs new file mode 100644 index 00000000..c59af8d4 --- /dev/null +++ b/system_zero/src/registers/boolean.rs @@ -0,0 +1,10 @@ +//! Boolean unit. Contains columns whose values must be 0 or 1. + +const NUM_BITS: usize = 128; + +pub const fn col_bit(index: usize) -> usize { + debug_assert!(index < NUM_BITS); + super::START_BOOLEAN + index +} + +pub(super) const END: usize = super::START_BOOLEAN + NUM_BITS; diff --git a/system_zero/src/registers/core.rs b/system_zero/src/registers/core.rs new file mode 100644 index 00000000..3fafab55 --- /dev/null +++ b/system_zero/src/registers/core.rs @@ -0,0 +1,20 @@ +//! Core registers. + +/// A cycle counter. Starts at 0; increments by 1. +pub(crate) const COL_CLOCK: usize = super::START_CORE; + +/// A column which contains the values `[0, ... 2^16 - 1]`, potentially with duplicates. Used for +/// 16-bit range checks. +/// +/// For ease of verification, we enforce that it must begin with 0 and end with `2^16 - 1`, and each +/// delta must be either 0 or 1. +pub(crate) const COL_RANGE_16: usize = COL_CLOCK + 1; + +/// Pointer to the current instruction. +pub(crate) const COL_INSTRUCTION_PTR: usize = COL_RANGE_16 + 1; +/// Pointer to the base of the current call's stack frame. +pub(crate) const COL_FRAME_PTR: usize = COL_INSTRUCTION_PTR + 1; +/// Pointer to the tip of the current call's stack frame. +pub(crate) const COL_STACK_PTR: usize = COL_FRAME_PTR + 1; + +pub(super) const END: usize = COL_STACK_PTR + 1; diff --git a/system_zero/src/registers/logic.rs b/system_zero/src/registers/logic.rs new file mode 100644 index 00000000..07f3f0e0 --- /dev/null +++ b/system_zero/src/registers/logic.rs @@ -0,0 +1,3 @@ +//! Logic unit. + +pub(super) const END: usize = super::START_LOGIC; diff --git a/system_zero/src/registers/lookup.rs b/system_zero/src/registers/lookup.rs new file mode 100644 index 00000000..eb773acf --- /dev/null +++ b/system_zero/src/registers/lookup.rs @@ -0,0 +1,21 @@ +//! Lookup unit. +//! See https://zcash.github.io/halo2/design/proving-system/lookup.html + +const START_UNIT: usize = super::START_LOOKUP; + +const NUM_LOOKUPS: usize = + super::range_check_16::NUM_RANGE_CHECKS + super::range_check_degree::NUM_RANGE_CHECKS; + +/// This column contains a permutation of the input values. +const fn col_permuted_input(i: usize) -> usize { + debug_assert!(i < NUM_LOOKUPS); + START_UNIT + 2 * i +} + +/// This column contains a permutation of the table values. 
+const fn col_permuted_table(i: usize) -> usize { + debug_assert!(i < NUM_LOOKUPS); + START_UNIT + 2 * i + 1 +} + +pub(super) const END: usize = START_UNIT + NUM_LOOKUPS; diff --git a/system_zero/src/registers/memory.rs b/system_zero/src/registers/memory.rs new file mode 100644 index 00000000..1373d0d8 --- /dev/null +++ b/system_zero/src/registers/memory.rs @@ -0,0 +1,3 @@ +//! Memory unit. + +pub(super) const END: usize = super::START_MEMORY; diff --git a/system_zero/src/registers/mod.rs b/system_zero/src/registers/mod.rs new file mode 100644 index 00000000..134a28bf --- /dev/null +++ b/system_zero/src/registers/mod.rs @@ -0,0 +1,20 @@ +pub(crate) mod arithmetic; +pub(crate) mod boolean; +pub(crate) mod core; +pub(crate) mod logic; +pub(crate) mod lookup; +pub(crate) mod memory; +pub(crate) mod permutation; +pub(crate) mod range_check_16; +pub(crate) mod range_check_degree; + +const START_ARITHMETIC: usize = 0; +const START_BOOLEAN: usize = arithmetic::END; +const START_CORE: usize = boolean::END; +const START_LOGIC: usize = core::END; +const START_LOOKUP: usize = logic::END; +const START_MEMORY: usize = lookup::END; +const START_PERMUTATION: usize = memory::END; +const START_RANGE_CHECK_16: usize = permutation::END; +const START_RANGE_CHECK_DEGREE: usize = range_check_16::END; +pub(crate) const NUM_COLUMNS: usize = range_check_degree::END; diff --git a/system_zero/src/registers/permutation.rs b/system_zero/src/registers/permutation.rs new file mode 100644 index 00000000..cde76af2 --- /dev/null +++ b/system_zero/src/registers/permutation.rs @@ -0,0 +1,57 @@ +//! Permutation unit. + +use plonky2::hash::hashing::SPONGE_WIDTH; +use plonky2::hash::poseidon; + +const START_FULL_FIRST: usize = super::START_PERMUTATION + SPONGE_WIDTH; + +pub const fn col_full_first_mid_sbox(round: usize, i: usize) -> usize { + debug_assert!(round < poseidon::HALF_N_FULL_ROUNDS); + debug_assert!(i < SPONGE_WIDTH); + START_FULL_FIRST + 2 * round * SPONGE_WIDTH + i +} + +pub const fn col_full_first_after_mds(round: usize, i: usize) -> usize { + debug_assert!(round < poseidon::HALF_N_FULL_ROUNDS); + debug_assert!(i < SPONGE_WIDTH); + START_FULL_FIRST + (2 * round + 1) * SPONGE_WIDTH + i +} + +const START_PARTIAL: usize = + col_full_first_after_mds(poseidon::HALF_N_FULL_ROUNDS - 1, SPONGE_WIDTH - 1) + 1; + +pub const fn col_partial_mid_sbox(round: usize) -> usize { + debug_assert!(round < poseidon::N_PARTIAL_ROUNDS); + START_PARTIAL + 2 * round +} + +pub const fn col_partial_after_sbox(round: usize) -> usize { + debug_assert!(round < poseidon::N_PARTIAL_ROUNDS); + START_PARTIAL + 2 * round + 1 +} + +const START_FULL_SECOND: usize = col_partial_after_sbox(poseidon::N_PARTIAL_ROUNDS - 1) + 1; + +pub const fn col_full_second_mid_sbox(round: usize, i: usize) -> usize { + debug_assert!(round <= poseidon::HALF_N_FULL_ROUNDS); + debug_assert!(i < SPONGE_WIDTH); + START_FULL_SECOND + 2 * round * SPONGE_WIDTH + i +} + +pub const fn col_full_second_after_mds(round: usize, i: usize) -> usize { + debug_assert!(round <= poseidon::HALF_N_FULL_ROUNDS); + debug_assert!(i < SPONGE_WIDTH); + START_FULL_SECOND + (2 * round + 1) * SPONGE_WIDTH + i +} + +pub const fn col_input(i: usize) -> usize { + debug_assert!(i < SPONGE_WIDTH); + super::START_PERMUTATION + i +} + +pub const fn col_output(i: usize) -> usize { + debug_assert!(i < SPONGE_WIDTH); + col_full_second_after_mds(poseidon::HALF_N_FULL_ROUNDS - 1, i) +} + +pub(super) const END: usize = col_output(SPONGE_WIDTH - 1) + 1; diff --git a/system_zero/src/registers/range_check_16.rs 
b/system_zero/src/registers/range_check_16.rs new file mode 100644 index 00000000..c44db494 --- /dev/null +++ b/system_zero/src/registers/range_check_16.rs @@ -0,0 +1,11 @@ +//! Range check unit which checks that values are in `[0, 2^16)`. + +pub(super) const NUM_RANGE_CHECKS: usize = 5; + +/// The input of the `i`th range check, i.e. the value being range checked. +pub(crate) const fn col_rc_16_input(i: usize) -> usize { + debug_assert!(i < NUM_RANGE_CHECKS); + super::START_RANGE_CHECK_16 + i +} + +pub(super) const END: usize = super::START_RANGE_CHECK_16 + NUM_RANGE_CHECKS; diff --git a/system_zero/src/registers/range_check_degree.rs b/system_zero/src/registers/range_check_degree.rs new file mode 100644 index 00000000..6d61e6e2 --- /dev/null +++ b/system_zero/src/registers/range_check_degree.rs @@ -0,0 +1,11 @@ +//! Range check unit which checks that values are in `[0, degree)`. + +pub(super) const NUM_RANGE_CHECKS: usize = 5; + +/// The input of the `i`th range check, i.e. the value being range checked. +pub(crate) const fn col_rc_degree_input(i: usize) -> usize { + debug_assert!(i < NUM_RANGE_CHECKS); + super::START_RANGE_CHECK_DEGREE + i +} + +pub(super) const END: usize = super::START_RANGE_CHECK_DEGREE + NUM_RANGE_CHECKS; diff --git a/system_zero/src/system_zero.rs b/system_zero/src/system_zero.rs index 70d3bbca..780b1d38 100644 --- a/system_zero/src/system_zero.rs +++ b/system_zero/src/system_zero.rs @@ -9,9 +9,12 @@ use starky::stark::Stark; use starky::vars::StarkEvaluationTargets; use starky::vars::StarkEvaluationVars; -use crate::column_layout::NUM_COLUMNS; +use crate::arithmetic::{ + eval_arithmetic_unit, eval_arithmetic_unit_recursively, generate_arithmetic_unit, +}; use crate::memory::TransactionMemory; use crate::public_input_layout::NUM_PUBLIC_INPUTS; +use crate::registers::NUM_COLUMNS; /// We require at least 2^16 rows as it helps support efficient 16-bit range checks. const MIN_TRACE_ROWS: usize = 1 << 16; @@ -34,10 +37,16 @@ impl, const D: usize> SystemZero { loop { let mut next_row = [F::ZERO; NUM_COLUMNS]; self.generate_next_row_core_registers(&row, &mut next_row); + generate_arithmetic_unit(&mut next_row); Self::generate_permutation_unit(&mut next_row); trace.push(row); row = next_row; + + // TODO: Replace with proper termination condition. + if trace.len() == (1 << 16) - 1 { + break; + } } trace.push(row); @@ -66,8 +75,9 @@ impl, const D: usize> Stark for SystemZero, { self.eval_core_registers(vars, yield_constr); + eval_arithmetic_unit(vars, yield_constr); Self::eval_permutation_unit(vars, yield_constr); - todo!() + // TODO: Other units } fn eval_ext_recursively( @@ -77,8 +87,9 @@ impl, const D: usize> Stark for SystemZero, ) { self.eval_core_registers_recursively(builder, vars, yield_constr); + eval_arithmetic_unit_recursively(builder, vars, yield_constr); Self::eval_permutation_unit_recursively(builder, vars, yield_constr); - todo!() + // TODO: Other units } fn constraint_degree(&self) -> usize { @@ -103,7 +114,7 @@ mod tests { use crate::system_zero::SystemZero; #[test] - #[ignore] // TODO + #[ignore] // A bit slow. 
fn run() -> Result<()> { type F = GoldilocksField; type C = PoseidonGoldilocksConfig; @@ -121,7 +132,6 @@ mod tests { } #[test] - #[ignore] // TODO fn degree() -> Result<()> { type F = GoldilocksField; type C = PoseidonGoldilocksConfig; From 7c71eb66908260f163938dc1ad4b1b7851893aed Mon Sep 17 00:00:00 2001 From: wborgeaud Date: Fri, 11 Feb 2022 10:25:51 +0100 Subject: [PATCH 07/15] Fix mul_add -> mul_sub typo --- system_zero/src/arithmetic/mod.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/system_zero/src/arithmetic/mod.rs b/system_zero/src/arithmetic/mod.rs index c635d58d..45a9f7d9 100644 --- a/system_zero/src/arithmetic/mod.rs +++ b/system_zero/src/arithmetic/mod.rs @@ -64,7 +64,7 @@ pub(crate) fn eval_arithmetic_unit_recursively, con // Check that the operation flag values are binary. for col in [IS_ADD, IS_SUB, IS_MUL, IS_DIV] { let val = local_values[col]; - let constraint = builder.mul_add_extension(val, val, val); + let constraint = builder.mul_sub_extension(val, val, val); yield_constr.constraint_wrapping(builder, constraint); } From 1d013b95ddfb02519c75cc8d5e3f64684a79b269 Mon Sep 17 00:00:00 2001 From: wborgeaud Date: Fri, 11 Feb 2022 16:22:57 +0100 Subject: [PATCH 08/15] Fix `hash_or_noop` in Merkle proof. --- plonky2/src/hash/merkle_proofs.rs | 2 +- plonky2/src/plonk/config.rs | 11 +++++++++++ 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/plonky2/src/hash/merkle_proofs.rs b/plonky2/src/hash/merkle_proofs.rs index f90f0657..7ef81570 100644 --- a/plonky2/src/hash/merkle_proofs.rs +++ b/plonky2/src/hash/merkle_proofs.rs @@ -32,7 +32,7 @@ pub(crate) fn verify_merkle_proof>( proof: &MerkleProof, ) -> Result<()> { let mut index = leaf_index; - let mut current_digest = H::hash_no_pad(&leaf_data); + let mut current_digest = H::hash_or_noop(&leaf_data); for &sibling_digest in proof.siblings.iter() { let bit = index & 1; index >>= 1; diff --git a/plonky2/src/plonk/config.rs b/plonky2/src/plonk/config.rs index fdca7037..76891240 100644 --- a/plonky2/src/plonk/config.rs +++ b/plonky2/src/plonk/config.rs @@ -46,6 +46,17 @@ pub trait Hasher: Sized + Clone + Debug + Eq + PartialEq { Self::hash_no_pad(&padded_input) } + /// Hash the slice if necessary to reduce its length to ~256 bits. If it already fits, this is a + /// no-op. + fn hash_or_noop(inputs: &[F]) -> Self::Hash { + if inputs.len() <= 4 { + let inputs_bytes = HashOut::from_partial(inputs).to_bytes(); + Self::Hash::from_bytes(&inputs_bytes) + } else { + Self::hash_no_pad(inputs) + } + } + fn two_to_one(left: Self::Hash, right: Self::Hash) -> Self::Hash; } From f7256a6efc361d206879b65c5240ba7fe25d7a3c Mon Sep 17 00:00:00 2001 From: wborgeaud Date: Fri, 11 Feb 2022 16:41:44 +0100 Subject: [PATCH 09/15] Other fixes --- plonky2/src/hash/merkle_tree.rs | 4 ++-- plonky2/src/hash/path_compression.rs | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/plonky2/src/hash/merkle_tree.rs b/plonky2/src/hash/merkle_tree.rs index e9460c14..f9890aa5 100644 --- a/plonky2/src/hash/merkle_tree.rs +++ b/plonky2/src/hash/merkle_tree.rs @@ -63,7 +63,7 @@ fn fill_subtree>( ) -> H::Hash { assert_eq!(leaves.len(), digests_buf.len() / 2 + 1); if digests_buf.is_empty() { - H::hash_no_pad(&leaves[0]) + H::hash_or_noop(&leaves[0]) } else { // Layout is: left recursive output || left child digest // || right child digest || right recursive output. 
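Patches 08 and 09 switch every place that digests a Merkle leaf — proof verification, tree construction, and path decompression — from `hash_no_pad` to `hash_or_noop`, so that a leaf which already fits in a digest is used as-is instead of being re-hashed, and prover and verifier apply the same leaf rule. A self-contained sketch of that rule for illustration only (toy names and a toy hash over plain `u64` digests, not the plonky2 API):

```rust
// Sketch of the leaf-digest rule restored by patches 08/09: a leaf that
// already fits in a 4-element digest is used directly (zero-padded), and
// only longer leaves are actually hashed. `leaf_digest` and `toy_hash`
// are illustrative stand-ins, not plonky2 functions.
fn leaf_digest(leaf: &[u64], hash: fn(&[u64]) -> [u64; 4]) -> [u64; 4] {
    if leaf.len() <= 4 {
        // "No-op" case: copy the (at most four) elements into the digest.
        let mut digest = [0u64; 4];
        digest[..leaf.len()].copy_from_slice(leaf);
        digest
    } else {
        // Otherwise fall back to a real hash of the full leaf.
        hash(leaf)
    }
}

// Toy stand-in for Poseidon, purely for demonstration.
fn toy_hash(xs: &[u64]) -> [u64; 4] {
    let mut acc = [0u64; 4];
    for (i, x) in xs.iter().enumerate() {
        acc[i % 4] = acc[i % 4].wrapping_mul(31).wrapping_add(*x);
    }
    acc
}

fn main() {
    // Short leaf: used verbatim, zero-padded to digest width.
    assert_eq!(leaf_digest(&[7, 8], toy_hash), [7, 8, 0, 0]);
    // Long leaf: genuinely hashed, so it is not just a truncation/padding.
    assert_ne!(leaf_digest(&[1, 2, 3, 4, 5], toy_hash), [1, 2, 3, 4, 0]);
}
```

With plonky2's 4-element `HashOut` over ~64-bit field elements, the `inputs.len() <= 4` check in patch 08's `hash_or_noop` plays the role of the length test in this sketch.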
@@ -99,7 +99,7 @@ fn fill_digests_buf>( .par_iter_mut() .zip(leaves) .for_each(|(cap_buf, leaf)| { - cap_buf.write(H::hash_no_pad(leaf)); + cap_buf.write(H::hash_or_noop(leaf)); }); return; } diff --git a/plonky2/src/hash/path_compression.rs b/plonky2/src/hash/path_compression.rs index 56c355fd..fe7850f4 100644 --- a/plonky2/src/hash/path_compression.rs +++ b/plonky2/src/hash/path_compression.rs @@ -66,7 +66,7 @@ pub(crate) fn decompress_merkle_proofs>( for (&i, v) in leaves_indices.iter().zip(leaves_data) { // Observe the leaves. - seen.insert(i + num_leaves, H::hash_no_pad(v)); + seen.insert(i + num_leaves, H::hash_or_noop(v)); } // Iterators over the siblings. From 736b65b0a7d595b0e1417bd08607edcda859e548 Mon Sep 17 00:00:00 2001 From: wborgeaud Date: Sat, 12 Feb 2022 15:18:20 +0100 Subject: [PATCH 10/15] PR feedback --- plonky2/src/hash/hashing.rs | 10 ---------- plonky2/src/plonk/config.rs | 9 +++++++-- 2 files changed, 7 insertions(+), 12 deletions(-) diff --git a/plonky2/src/hash/hashing.rs b/plonky2/src/hash/hashing.rs index 468bd1b8..9d043ea3 100644 --- a/plonky2/src/hash/hashing.rs +++ b/plonky2/src/hash/hashing.rs @@ -12,16 +12,6 @@ pub(crate) const SPONGE_RATE: usize = 8; pub(crate) const SPONGE_CAPACITY: usize = 4; pub const SPONGE_WIDTH: usize = SPONGE_RATE + SPONGE_CAPACITY; -/// Hash the slice if necessary to reduce its length to ~256 bits. If it already fits, this is a -/// no-op. -pub fn hash_or_noop>(inputs: &[F]) -> HashOut { - if inputs.len() <= 4 { - HashOut::from_partial(inputs) - } else { - hash_n_to_hash_no_pad::(inputs) - } -} - impl, const D: usize> CircuitBuilder { pub fn hash_or_noop>(&mut self, inputs: Vec) -> HashOutTarget { let zero = self.zero(); diff --git a/plonky2/src/plonk/config.rs b/plonky2/src/plonk/config.rs index 76891240..40179c38 100644 --- a/plonky2/src/plonk/config.rs +++ b/plonky2/src/plonk/config.rs @@ -3,6 +3,7 @@ use std::fmt::Debug; use plonky2_field::extension_field::quadratic::QuadraticExtension; use plonky2_field::extension_field::{Extendable, FieldExtension}; use plonky2_field::goldilocks_field::GoldilocksField; +use plonky2_util::ceil_div_usize; use serde::{de::DeserializeOwned, Serialize}; use crate::hash::hash_types::HashOut; @@ -49,8 +50,12 @@ pub trait Hasher: Sized + Clone + Debug + Eq + PartialEq { /// Hash the slice if necessary to reduce its length to ~256 bits. If it already fits, this is a /// no-op. 
fn hash_or_noop(inputs: &[F]) -> Self::Hash { - if inputs.len() <= 4 { - let inputs_bytes = HashOut::from_partial(inputs).to_bytes(); + if inputs.len() * ceil_div_usize(F::BITS, 8) <= Self::HASH_SIZE { + let mut inputs_bytes = inputs + .iter() + .flat_map(|x| x.to_canonical_u64().to_le_bytes()) + .collect::>(); + inputs_bytes.resize(Self::HASH_SIZE, 0); Self::Hash::from_bytes(&inputs_bytes) } else { Self::hash_no_pad(inputs) From 7af2d05828240123e70f108dc0baf67a5338788c Mon Sep 17 00:00:00 2001 From: wborgeaud Date: Sun, 13 Feb 2022 15:04:40 +0100 Subject: [PATCH 11/15] Save allocation and add const generic bound --- plonky2/benches/merkle.rs | 7 +++++- plonky2/src/fri/oracle.rs | 17 ++++++++++---- plonky2/src/fri/proof.rs | 5 ++++- plonky2/src/fri/prover.rs | 10 +++++++-- plonky2/src/fri/verifier.rs | 21 ++++++++++------- plonky2/src/gates/gate_testing.rs | 7 ++++-- plonky2/src/hash/merkle_proofs.rs | 5 ++++- plonky2/src/hash/merkle_tree.rs | 25 +++++++++++++-------- plonky2/src/hash/path_compression.rs | 5 ++++- plonky2/src/plonk/circuit_builder.rs | 15 ++++++++++--- plonky2/src/plonk/circuit_data.rs | 30 ++++++++++++++++++++----- plonky2/src/plonk/config.rs | 18 ++++++++------- plonky2/src/plonk/proof.rs | 15 ++++++++++--- plonky2/src/plonk/prover.rs | 5 ++++- plonky2/src/plonk/recursive_verifier.rs | 15 ++++++++++--- plonky2/src/plonk/verifier.rs | 10 +++++++-- starky/src/prover.rs | 3 ++- starky/src/verifier.rs | 4 +++- 18 files changed, 160 insertions(+), 57 deletions(-) diff --git a/plonky2/benches/merkle.rs b/plonky2/benches/merkle.rs index 7445682b..8bc43730 100644 --- a/plonky2/benches/merkle.rs +++ b/plonky2/benches/merkle.rs @@ -1,3 +1,5 @@ +#![feature(generic_const_exprs)] + use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion}; use plonky2::field::goldilocks_field::GoldilocksField; use plonky2::hash::hash_types::RichField; @@ -9,7 +11,10 @@ use tynm::type_name; const ELEMS_PER_LEAF: usize = 135; -pub(crate) fn bench_merkle_tree>(c: &mut Criterion) { +pub(crate) fn bench_merkle_tree>(c: &mut Criterion) +where + [(); H::HASH_SIZE]:, +{ let mut group = c.benchmark_group(&format!( "merkle-tree<{}, {}>", type_name::(), diff --git a/plonky2/src/fri/oracle.rs b/plonky2/src/fri/oracle.rs index 0922962a..bd1e9ac5 100644 --- a/plonky2/src/fri/oracle.rs +++ b/plonky2/src/fri/oracle.rs @@ -12,7 +12,7 @@ use crate::fri::FriParams; use crate::hash::hash_types::RichField; use crate::hash::merkle_tree::MerkleTree; use crate::iop::challenger::Challenger; -use crate::plonk::config::GenericConfig; +use crate::plonk::config::{GenericConfig, Hasher}; use crate::timed; use crate::util::reducing::ReducingFactor; use crate::util::reverse_bits; @@ -43,7 +43,10 @@ impl, C: GenericConfig, const D: usize> cap_height: usize, timing: &mut TimingTree, fft_root_table: Option<&FftRootTable>, - ) -> Self { + ) -> Self + where + [(); C::Hasher::HASH_SIZE]:, + { let coeffs = timed!( timing, "IFFT", @@ -68,7 +71,10 @@ impl, C: GenericConfig, const D: usize> cap_height: usize, timing: &mut TimingTree, fft_root_table: Option<&FftRootTable>, - ) -> Self { + ) -> Self + where + [(); C::Hasher::HASH_SIZE]:, + { let degree = polynomials[0].len(); let lde_values = timed!( timing, @@ -133,7 +139,10 @@ impl, C: GenericConfig, const D: usize> challenger: &mut Challenger, fri_params: &FriParams, timing: &mut TimingTree, - ) -> FriProof { + ) -> FriProof + where + [(); C::Hasher::HASH_SIZE]:, + { assert!(D > 1, "Not implemented for D=1."); let alpha = challenger.get_extension_challenge::(); let mut alpha 
= ReducingFactor::new(alpha); diff --git a/plonky2/src/fri/proof.rs b/plonky2/src/fri/proof.rs index 44f74cba..9c6961a4 100644 --- a/plonky2/src/fri/proof.rs +++ b/plonky2/src/fri/proof.rs @@ -245,7 +245,10 @@ impl, H: Hasher, const D: usize> CompressedFriPr challenges: &ProofChallenges, fri_inferred_elements: FriInferredElements, params: &FriParams, - ) -> FriProof { + ) -> FriProof + where + [(); H::HASH_SIZE]:, + { let CompressedFriProof { commit_phase_merkle_caps, query_round_proofs, diff --git a/plonky2/src/fri/prover.rs b/plonky2/src/fri/prover.rs index 5cd5fdf1..5a20ab9d 100644 --- a/plonky2/src/fri/prover.rs +++ b/plonky2/src/fri/prover.rs @@ -24,7 +24,10 @@ pub fn fri_proof, C: GenericConfig, const challenger: &mut Challenger, fri_params: &FriParams, timing: &mut TimingTree, -) -> FriProof { +) -> FriProof +where + [(); C::Hasher::HASH_SIZE]:, +{ let n = lde_polynomial_values.len(); assert_eq!(lde_polynomial_coeffs.len(), n); @@ -68,7 +71,10 @@ fn fri_committed_trees, C: GenericConfig, ) -> ( Vec>, PolynomialCoeffs, -) { +) +where + [(); C::Hasher::HASH_SIZE]:, +{ let mut trees = Vec::new(); let mut shift = F::MULTIPLICATIVE_GROUP_GENERATOR; diff --git a/plonky2/src/fri/verifier.rs b/plonky2/src/fri/verifier.rs index 49cfa053..2607ab0d 100644 --- a/plonky2/src/fri/verifier.rs +++ b/plonky2/src/fri/verifier.rs @@ -56,18 +56,17 @@ pub(crate) fn fri_verify_proof_of_work, const D: us Ok(()) } -pub fn verify_fri_proof< - F: RichField + Extendable, - C: GenericConfig, - const D: usize, ->( +pub fn verify_fri_proof, C: GenericConfig, const D: usize>( instance: &FriInstanceInfo, openings: &FriOpenings, challenges: &FriChallenges, initial_merkle_caps: &[MerkleCap], proof: &FriProof, params: &FriParams, -) -> Result<()> { +) -> Result<()> +where + [(); C::Hasher::HASH_SIZE]:, +{ ensure!( params.final_poly_len() == proof.final_poly.len(), "Final polynomial has wrong degree." @@ -112,7 +111,10 @@ fn fri_verify_initial_proof>( x_index: usize, proof: &FriInitialTreeProof, initial_merkle_caps: &[MerkleCap], -) -> Result<()> { +) -> Result<()> +where + [(); H::HASH_SIZE]:, +{ for ((evals, merkle_proof), cap) in proof.evals_proofs.iter().zip(initial_merkle_caps) { verify_merkle_proof::(evals.clone(), x_index, cap, merkle_proof)?; } @@ -177,7 +179,10 @@ fn fri_verifier_query_round< n: usize, round_proof: &FriQueryRound, params: &FriParams, -) -> Result<()> { +) -> Result<()> +where + [(); C::Hasher::HASH_SIZE]:, +{ fri_verify_initial_proof::( x_index, &round_proof.initial_trees_proof, diff --git a/plonky2/src/gates/gate_testing.rs b/plonky2/src/gates/gate_testing.rs index ea1ef9a4..51768ba8 100644 --- a/plonky2/src/gates/gate_testing.rs +++ b/plonky2/src/gates/gate_testing.rs @@ -10,7 +10,7 @@ use crate::hash::hash_types::RichField; use crate::iop::witness::{PartialWitness, Witness}; use crate::plonk::circuit_builder::CircuitBuilder; use crate::plonk::circuit_data::CircuitConfig; -use crate::plonk::config::GenericConfig; +use crate::plonk::config::{GenericConfig, Hasher}; use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBaseBatch}; use crate::plonk::verifier::verify; use crate::util::transpose; @@ -92,7 +92,10 @@ pub fn test_eval_fns< const D: usize, >( gate: G, -) -> Result<()> { +) -> Result<()> +where + [(); C::Hasher::HASH_SIZE]:, +{ // Test that `eval_unfiltered` and `eval_unfiltered_base` are coherent. 
let wires_base = F::rand_vec(gate.num_wires()); let constants_base = F::rand_vec(gate.num_constants()); diff --git a/plonky2/src/hash/merkle_proofs.rs b/plonky2/src/hash/merkle_proofs.rs index 7ef81570..c3ebf406 100644 --- a/plonky2/src/hash/merkle_proofs.rs +++ b/plonky2/src/hash/merkle_proofs.rs @@ -30,7 +30,10 @@ pub(crate) fn verify_merkle_proof>( leaf_index: usize, merkle_cap: &MerkleCap, proof: &MerkleProof, -) -> Result<()> { +) -> Result<()> +where + [(); H::HASH_SIZE]:, +{ let mut index = leaf_index; let mut current_digest = H::hash_or_noop(&leaf_data); for &sibling_digest in proof.siblings.iter() { diff --git a/plonky2/src/hash/merkle_tree.rs b/plonky2/src/hash/merkle_tree.rs index f9890aa5..5fbc441c 100644 --- a/plonky2/src/hash/merkle_tree.rs +++ b/plonky2/src/hash/merkle_tree.rs @@ -60,7 +60,10 @@ fn capacity_up_to_mut(v: &mut Vec, len: usize) -> &mut [MaybeUninit] { fn fill_subtree>( digests_buf: &mut [MaybeUninit], leaves: &[Vec], -) -> H::Hash { +) -> H::Hash +where + [(); H::HASH_SIZE]:, +{ assert_eq!(leaves.len(), digests_buf.len() / 2 + 1); if digests_buf.is_empty() { H::hash_or_noop(&leaves[0]) @@ -89,7 +92,9 @@ fn fill_digests_buf>( cap_buf: &mut [MaybeUninit], leaves: &[Vec], cap_height: usize, -) { +) where + [(); H::HASH_SIZE]:, +{ // Special case of a tree that's all cap. The usual case will panic because we'll try to split // an empty slice into chunks of `0`. (We would not need this if there was a way to split into // `blah` chunks as opposed to chunks _of_ `blah`.) @@ -121,7 +126,10 @@ fn fill_digests_buf>( } impl> MerkleTree { - pub fn new(leaves: Vec>, cap_height: usize) -> Self { + pub fn new(leaves: Vec>, cap_height: usize) -> Self + where + [(); H::HASH_SIZE]:, + { let log2_leaves_len = log2_strict(leaves.len()); assert!( cap_height <= log2_leaves_len, @@ -208,14 +216,13 @@ mod tests { (0..n).map(|_| F::rand_vec(k)).collect() } - fn verify_all_leaves< - F: RichField + Extendable, - C: GenericConfig, - const D: usize, - >( + fn verify_all_leaves, C: GenericConfig, const D: usize>( leaves: Vec>, cap_height: usize, - ) -> Result<()> { + ) -> Result<()> + where + [(); C::Hasher::HASH_SIZE]:, + { let tree = MerkleTree::::new(leaves.clone(), cap_height); for (i, leaf) in leaves.into_iter().enumerate() { let proof = tree.prove(i); diff --git a/plonky2/src/hash/path_compression.rs b/plonky2/src/hash/path_compression.rs index fe7850f4..6dae3d94 100644 --- a/plonky2/src/hash/path_compression.rs +++ b/plonky2/src/hash/path_compression.rs @@ -57,7 +57,10 @@ pub(crate) fn decompress_merkle_proofs>( compressed_proofs: &[MerkleProof], height: usize, cap_height: usize, -) -> Vec> { +) -> Vec> +where + [(); H::HASH_SIZE]:, +{ let num_leaves = 1 << height; let compressed_proofs = compressed_proofs.to_vec(); let mut decompressed_proofs = Vec::with_capacity(compressed_proofs.len()); diff --git a/plonky2/src/plonk/circuit_builder.rs b/plonky2/src/plonk/circuit_builder.rs index cf89bf1a..7811c0db 100644 --- a/plonky2/src/plonk/circuit_builder.rs +++ b/plonky2/src/plonk/circuit_builder.rs @@ -610,7 +610,10 @@ impl, const D: usize> CircuitBuilder { } /// Builds a "full circuit", with both prover and verifier data. 
- pub fn build>(mut self) -> CircuitData { + pub fn build>(mut self) -> CircuitData + where + [(); C::Hasher::HASH_SIZE]:, + { let mut timing = TimingTree::new("preprocess", Level::Trace); let start = Instant::now(); let rate_bits = self.config.fri_config.rate_bits; @@ -776,7 +779,10 @@ impl, const D: usize> CircuitBuilder { } /// Builds a "prover circuit", with data needed to generate proofs but not verify them. - pub fn build_prover>(self) -> ProverCircuitData { + pub fn build_prover>(self) -> ProverCircuitData + where + [(); C::Hasher::HASH_SIZE]:, + { // TODO: Can skip parts of this. let CircuitData { prover_only, @@ -790,7 +796,10 @@ impl, const D: usize> CircuitBuilder { } /// Builds a "verifier circuit", with data needed to verify proofs but not generate them. - pub fn build_verifier>(self) -> VerifierCircuitData { + pub fn build_verifier>(self) -> VerifierCircuitData + where + [(); C::Hasher::HASH_SIZE]:, + { // TODO: Can skip parts of this. let CircuitData { verifier_only, diff --git a/plonky2/src/plonk/circuit_data.rs b/plonky2/src/plonk/circuit_data.rs index 7e667b8d..3d4ee2df 100644 --- a/plonky2/src/plonk/circuit_data.rs +++ b/plonky2/src/plonk/circuit_data.rs @@ -104,7 +104,10 @@ pub struct CircuitData, C: GenericConfig, impl, C: GenericConfig, const D: usize> CircuitData { - pub fn prove(&self, inputs: PartialWitness) -> Result> { + pub fn prove(&self, inputs: PartialWitness) -> Result> + where + [(); C::Hasher::HASH_SIZE]:, + { prove( &self.prover_only, &self.common, @@ -113,14 +116,20 @@ impl, C: GenericConfig, const D: usize> ) } - pub fn verify(&self, proof_with_pis: ProofWithPublicInputs) -> Result<()> { + pub fn verify(&self, proof_with_pis: ProofWithPublicInputs) -> Result<()> + where + [(); C::Hasher::HASH_SIZE]:, + { verify(proof_with_pis, &self.verifier_only, &self.common) } pub fn verify_compressed( &self, compressed_proof_with_pis: CompressedProofWithPublicInputs, - ) -> Result<()> { + ) -> Result<()> + where + [(); C::Hasher::HASH_SIZE]:, + { compressed_proof_with_pis.verify(&self.verifier_only, &self.common) } } @@ -144,7 +153,10 @@ pub struct ProverCircuitData< impl, C: GenericConfig, const D: usize> ProverCircuitData { - pub fn prove(&self, inputs: PartialWitness) -> Result> { + pub fn prove(&self, inputs: PartialWitness) -> Result> + where + [(); C::Hasher::HASH_SIZE]:, + { prove( &self.prover_only, &self.common, @@ -168,14 +180,20 @@ pub struct VerifierCircuitData< impl, C: GenericConfig, const D: usize> VerifierCircuitData { - pub fn verify(&self, proof_with_pis: ProofWithPublicInputs) -> Result<()> { + pub fn verify(&self, proof_with_pis: ProofWithPublicInputs) -> Result<()> + where + [(); C::Hasher::HASH_SIZE]:, + { verify(proof_with_pis, &self.verifier_only, &self.common) } pub fn verify_compressed( &self, compressed_proof_with_pis: CompressedProofWithPublicInputs, - ) -> Result<()> { + ) -> Result<()> + where + [(); C::Hasher::HASH_SIZE]:, + { compressed_proof_with_pis.verify(&self.verifier_only, &self.common) } } diff --git a/plonky2/src/plonk/config.rs b/plonky2/src/plonk/config.rs index 40179c38..cb6d9a9b 100644 --- a/plonky2/src/plonk/config.rs +++ b/plonky2/src/plonk/config.rs @@ -3,7 +3,6 @@ use std::fmt::Debug; use plonky2_field::extension_field::quadratic::QuadraticExtension; use plonky2_field::extension_field::{Extendable, FieldExtension}; use plonky2_field::goldilocks_field::GoldilocksField; -use plonky2_util::ceil_div_usize; use serde::{de::DeserializeOwned, Serialize}; use crate::hash::hash_types::HashOut; @@ -49,13 +48,16 @@ pub trait 
Hasher: Sized + Clone + Debug + Eq + PartialEq { /// Hash the slice if necessary to reduce its length to ~256 bits. If it already fits, this is a /// no-op. - fn hash_or_noop(inputs: &[F]) -> Self::Hash { - if inputs.len() * ceil_div_usize(F::BITS, 8) <= Self::HASH_SIZE { - let mut inputs_bytes = inputs - .iter() - .flat_map(|x| x.to_canonical_u64().to_le_bytes()) - .collect::>(); - inputs_bytes.resize(Self::HASH_SIZE, 0); + fn hash_or_noop(inputs: &[F]) -> Self::Hash + where + [(); Self::HASH_SIZE]:, + { + if inputs.len() <= 4 { + let mut inputs_bytes = [0u8; Self::HASH_SIZE]; + for i in 0..inputs.len() { + inputs_bytes[i * 8..(i + 1) * 8] + .copy_from_slice(&inputs[i].to_canonical_u64().to_le_bytes()); + } Self::Hash::from_bytes(&inputs_bytes) } else { Self::hash_no_pad(inputs) diff --git a/plonky2/src/plonk/proof.rs b/plonky2/src/plonk/proof.rs index 3de608d4..145ef694 100644 --- a/plonky2/src/plonk/proof.rs +++ b/plonky2/src/plonk/proof.rs @@ -138,7 +138,10 @@ impl, C: GenericConfig, const D: usize> challenges: &ProofChallenges, fri_inferred_elements: FriInferredElements, params: &FriParams, - ) -> Proof { + ) -> Proof + where + [(); C::Hasher::HASH_SIZE]:, + { let CompressedProof { wires_cap, plonk_zs_partial_products_cap, @@ -174,7 +177,10 @@ impl, C: GenericConfig, const D: usize> pub fn decompress( self, common_data: &CommonCircuitData, - ) -> anyhow::Result> { + ) -> anyhow::Result> + where + [(); C::Hasher::HASH_SIZE]:, + { let challenges = self.get_challenges(self.get_public_inputs_hash(), common_data)?; let fri_inferred_elements = self.get_inferred_elements(&challenges, common_data); let decompressed_proof = @@ -190,7 +196,10 @@ impl, C: GenericConfig, const D: usize> self, verifier_data: &VerifierOnlyCircuitData, common_data: &CommonCircuitData, - ) -> anyhow::Result<()> { + ) -> anyhow::Result<()> + where + [(); C::Hasher::HASH_SIZE]:, + { ensure!( self.public_inputs.len() == common_data.num_public_inputs, "Number of public inputs doesn't match circuit data." 
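The `[(); C::Hasher::HASH_SIZE]:` bounds that patch 11 threads through the prover, verifier, and circuit data exist so that `hash_or_noop` can pack bytes into a stack array `[0u8; Self::HASH_SIZE]` instead of building a heap `Vec` — the saved allocation in the commit title. A minimal, nightly-only sketch of the pattern under `generic_const_exprs`, which these crates already enable; the `Digest` trait, `hash_or_noop_u64`, and `ToyHasher` below are illustrative stand-ins, not the plonky2 `Hasher` trait:

```rust
#![allow(incomplete_features)]
#![feature(generic_const_exprs)]

/// Toy digest trait. The associated const sizes a stack buffer, so any
/// method (and any generic caller) that touches `[u8; Self::HASH_SIZE]`
/// must carry the `[(); Self::HASH_SIZE]:` bound.
trait Digest: Sized {
    const HASH_SIZE: usize;
    type Hash;

    fn hash_from_bytes(bytes: &[u8]) -> Self::Hash;

    /// Packs little-endian `u64` limbs into a fixed-size stack buffer when
    /// they fit; a real implementation would hash longer inputs instead.
    fn hash_or_noop_u64(inputs: &[u64]) -> Self::Hash
    where
        [(); Self::HASH_SIZE]:,
    {
        if inputs.len() * 8 <= Self::HASH_SIZE {
            let mut bytes = [0u8; Self::HASH_SIZE]; // stack buffer, no Vec
            for (i, x) in inputs.iter().enumerate() {
                bytes[i * 8..(i + 1) * 8].copy_from_slice(&x.to_le_bytes());
            }
            Self::hash_from_bytes(&bytes)
        } else {
            unimplemented!("hash longer inputs with the real sponge")
        }
    }
}

/// Illustrative 256-bit "hasher"; only the no-op path is exercised here.
struct ToyHasher;

impl Digest for ToyHasher {
    const HASH_SIZE: usize = 32;
    type Hash = [u8; 32];

    fn hash_from_bytes(bytes: &[u8]) -> [u8; 32] {
        let mut out = [0u8; 32];
        out.copy_from_slice(&bytes[..32]);
        out
    }
}

fn main() {
    // Four u64 limbs occupy 32 bytes, mirroring the `inputs.len() <= 4`
    // fast path in the `hash_or_noop` hunk above.
    let h = ToyHasher::hash_or_noop_u64(&[1, 2, 3, 4]);
    assert_eq!(&h[..8], &1u64.to_le_bytes()[..]);
}
```

Because the array length involves an associated const, every generic caller has to repeat the `[(); Self::HASH_SIZE]:` where-clause, which is exactly the churn visible across the diffs in this patch.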
diff --git a/plonky2/src/plonk/prover.rs b/plonky2/src/plonk/prover.rs index d49014f0..1d99b60a 100644 --- a/plonky2/src/plonk/prover.rs +++ b/plonky2/src/plonk/prover.rs @@ -31,7 +31,10 @@ pub(crate) fn prove, C: GenericConfig, co common_data: &CommonCircuitData, inputs: PartialWitness, timing: &mut TimingTree, -) -> Result> { +) -> Result> +where + [(); C::Hasher::HASH_SIZE]:, +{ let config = &common_data.config; let num_challenges = config.num_challenges; let quotient_degree = common_data.quotient_degree(); diff --git a/plonky2/src/plonk/recursive_verifier.rs b/plonky2/src/plonk/recursive_verifier.rs index c91cbba2..6210bb29 100644 --- a/plonky2/src/plonk/recursive_verifier.rs +++ b/plonky2/src/plonk/recursive_verifier.rs @@ -187,7 +187,9 @@ mod tests { use crate::gates::noop::NoopGate; use crate::iop::witness::{PartialWitness, Witness}; use crate::plonk::circuit_data::{CircuitConfig, VerifierOnlyCircuitData}; - use crate::plonk::config::{GenericConfig, KeccakGoldilocksConfig, PoseidonGoldilocksConfig}; + use crate::plonk::config::{ + GenericConfig, Hasher, KeccakGoldilocksConfig, PoseidonGoldilocksConfig, + }; use crate::plonk::proof::{CompressedProofWithPublicInputs, ProofWithPublicInputs}; use crate::plonk::prover::prove; use crate::util::timing::TimingTree; @@ -322,7 +324,10 @@ mod tests { ProofWithPublicInputs, VerifierOnlyCircuitData, CommonCircuitData, - )> { + )> + where + [(); C::Hasher::HASH_SIZE]:, + { let mut builder = CircuitBuilder::::new(config.clone()); for _ in 0..num_dummy_gates { builder.add_gate(NoopGate, vec![]); @@ -356,6 +361,7 @@ mod tests { )> where InnerC::Hasher: AlgebraicHasher, + [(); C::Hasher::HASH_SIZE]:, { let mut builder = CircuitBuilder::::new(config.clone()); let mut pw = PartialWitness::new(); @@ -407,7 +413,10 @@ mod tests { >( proof: &ProofWithPublicInputs, cd: &CommonCircuitData, - ) -> Result<()> { + ) -> Result<()> + where + [(); C::Hasher::HASH_SIZE]:, + { let proof_bytes = proof.to_bytes()?; info!("Proof length: {} bytes", proof_bytes.len()); let proof_from_bytes = ProofWithPublicInputs::from_bytes(proof_bytes, cd)?; diff --git a/plonky2/src/plonk/verifier.rs b/plonky2/src/plonk/verifier.rs index 5d69dcb1..ee0e976f 100644 --- a/plonky2/src/plonk/verifier.rs +++ b/plonky2/src/plonk/verifier.rs @@ -15,7 +15,10 @@ pub(crate) fn verify, C: GenericConfig, c proof_with_pis: ProofWithPublicInputs, verifier_data: &VerifierOnlyCircuitData, common_data: &CommonCircuitData, -) -> Result<()> { +) -> Result<()> +where + [(); C::Hasher::HASH_SIZE]:, +{ ensure!( proof_with_pis.public_inputs.len() == common_data.num_public_inputs, "Number of public inputs doesn't match circuit data." 
@@ -42,7 +45,10 @@ pub(crate) fn verify_with_challenges< challenges: ProofChallenges, verifier_data: &VerifierOnlyCircuitData, common_data: &CommonCircuitData, -) -> Result<()> { +) -> Result<()> +where + [(); C::Hasher::HASH_SIZE]:, +{ let local_constants = &proof.openings.constants; let local_wires = &proof.openings.wires; let vars = EvaluationVars { diff --git a/starky/src/prover.rs b/starky/src/prover.rs index de97ecce..e88aa619 100644 --- a/starky/src/prover.rs +++ b/starky/src/prover.rs @@ -7,7 +7,7 @@ use plonky2::field::zero_poly_coset::ZeroPolyOnCoset; use plonky2::fri::oracle::PolynomialBatch; use plonky2::hash::hash_types::RichField; use plonky2::iop::challenger::Challenger; -use plonky2::plonk::config::GenericConfig; +use plonky2::plonk::config::{GenericConfig, Hasher}; use plonky2::timed; use plonky2::util::timing::TimingTree; use plonky2::util::transpose; @@ -33,6 +33,7 @@ where S: Stark, [(); S::COLUMNS]:, [(); S::PUBLIC_INPUTS]:, + [(); C::Hasher::HASH_SIZE]:, { let degree = trace.len(); let degree_bits = log2_strict(degree); diff --git a/starky/src/verifier.rs b/starky/src/verifier.rs index 91a51bed..8bf1faab 100644 --- a/starky/src/verifier.rs +++ b/starky/src/verifier.rs @@ -3,7 +3,7 @@ use plonky2::field::extension_field::{Extendable, FieldExtension}; use plonky2::field::field_types::Field; use plonky2::fri::verifier::verify_fri_proof; use plonky2::hash::hash_types::RichField; -use plonky2::plonk::config::GenericConfig; +use plonky2::plonk::config::{GenericConfig, Hasher}; use plonky2::plonk::plonk_common::reduce_with_powers; use plonky2_util::log2_strict; @@ -26,6 +26,7 @@ pub fn verify< where [(); S::COLUMNS]:, [(); S::PUBLIC_INPUTS]:, + [(); C::Hasher::HASH_SIZE]:, { let degree_bits = log2_strict(recover_degree(&proof_with_pis.proof, config)); let challenges = proof_with_pis.get_challenges(config, degree_bits)?; @@ -47,6 +48,7 @@ pub(crate) fn verify_with_challenges< where [(); S::COLUMNS]:, [(); S::PUBLIC_INPUTS]:, + [(); C::Hasher::HASH_SIZE]:, { let StarkProofWithPublicInputs { proof, From 55ca718a777fcdc98b54a37a8fd512b9efc5d022 Mon Sep 17 00:00:00 2001 From: Daniel Lubarov Date: Sun, 13 Feb 2022 10:51:27 -0800 Subject: [PATCH 12/15] Test no longer ignored --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 1db24c69..4dbd5906 100644 --- a/README.md +++ b/README.md @@ -26,7 +26,7 @@ in the Plonky2 directory. 
To see recursion performance, one can run this test, which generates a chain of three recursion proofs: ```sh -RUST_LOG=debug RUSTFLAGS=-Ctarget-cpu=native cargo test --release test_recursive_recursive_verifier -- --ignored +RUST_LOG=debug RUSTFLAGS=-Ctarget-cpu=native cargo test --release test_recursive_recursive_verifier ``` From c9171517a4ed57ca41c4cf831af09211e92d88d8 Mon Sep 17 00:00:00 2001 From: BGluth Date: Mon, 14 Feb 2022 10:53:20 -0700 Subject: [PATCH 13/15] Derived more traits for ecdsa types --- plonky2/src/curve/curve_types.rs | 3 ++- plonky2/src/curve/ecdsa.rs | 8 +++++--- plonky2/src/curve/secp256k1.rs | 3 ++- 3 files changed, 9 insertions(+), 5 deletions(-) diff --git a/plonky2/src/curve/curve_types.rs b/plonky2/src/curve/curve_types.rs index 0a9e8711..15f80bc6 100644 --- a/plonky2/src/curve/curve_types.rs +++ b/plonky2/src/curve/curve_types.rs @@ -3,6 +3,7 @@ use std::ops::Neg; use plonky2_field::field_types::{Field, PrimeField}; use plonky2_field::ops::Square; +use serde::{Deserialize, Serialize}; // To avoid implementation conflicts from associated types, // see https://github.com/rust-lang/rust/issues/20400 @@ -36,7 +37,7 @@ pub trait Curve: 'static + Sync + Sized + Copy + Debug { } /// A point on a short Weierstrass curve, represented in affine coordinates. -#[derive(Copy, Clone, Debug)] +#[derive(Copy, Clone, Debug, Deserialize, Serialize)] pub struct AffinePoint { pub x: C::BaseField, pub y: C::BaseField, diff --git a/plonky2/src/curve/ecdsa.rs b/plonky2/src/curve/ecdsa.rs index 3a5d3c7a..11e05535 100644 --- a/plonky2/src/curve/ecdsa.rs +++ b/plonky2/src/curve/ecdsa.rs @@ -1,17 +1,19 @@ +use serde::{Deserialize, Serialize}; + use crate::curve::curve_msm::msm_parallel; use crate::curve::curve_types::{base_to_scalar, AffinePoint, Curve, CurveScalar}; use crate::field::field_types::Field; -#[derive(Copy, Clone, Debug)] +#[derive(Copy, Clone, Debug, Deserialize, Eq, Hash, PartialEq, Serialize)] pub struct ECDSASignature { pub r: C::ScalarField, pub s: C::ScalarField, } -#[derive(Copy, Clone, Debug)] +#[derive(Copy, Clone, Debug, Deserialize, Eq, Hash, PartialEq, Serialize)] pub struct ECDSASecretKey(pub C::ScalarField); -#[derive(Copy, Clone, Debug)] +#[derive(Copy, Clone, Debug, Deserialize, Eq, PartialEq, Serialize)] pub struct ECDSAPublicKey(pub AffinePoint); pub fn sign_message(msg: C::ScalarField, sk: ECDSASecretKey) -> ECDSASignature { diff --git a/plonky2/src/curve/secp256k1.rs b/plonky2/src/curve/secp256k1.rs index 6a460735..18040dae 100644 --- a/plonky2/src/curve/secp256k1.rs +++ b/plonky2/src/curve/secp256k1.rs @@ -1,10 +1,11 @@ use plonky2_field::field_types::Field; use plonky2_field::secp256k1_base::Secp256K1Base; use plonky2_field::secp256k1_scalar::Secp256K1Scalar; +use serde::{Deserialize, Serialize}; use crate::curve::curve_types::{AffinePoint, Curve}; -#[derive(Debug, Copy, Clone)] +#[derive(Debug, Copy, Clone, Deserialize, Eq, Hash, PartialEq, Serialize)] pub struct Secp256K1; impl Curve for Secp256K1 { From 1467732616868f43055f44012096477933959ec5 Mon Sep 17 00:00:00 2001 From: BGluth Date: Mon, 14 Feb 2022 12:41:24 -0700 Subject: [PATCH 14/15] Impled `Hash` for `AffinePoint` --- plonky2/src/curve/curve_types.rs | 12 ++++++++++++ plonky2/src/curve/ecdsa.rs | 2 +- 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/plonky2/src/curve/curve_types.rs b/plonky2/src/curve/curve_types.rs index 15f80bc6..264120c7 100644 --- a/plonky2/src/curve/curve_types.rs +++ b/plonky2/src/curve/curve_types.rs @@ -1,4 +1,5 @@ use std::fmt::Debug; +use 
std::hash::Hash; use std::ops::Neg; use plonky2_field::field_types::{Field, PrimeField}; @@ -120,6 +121,17 @@ impl PartialEq for AffinePoint { impl Eq for AffinePoint {} +impl Hash for AffinePoint { + fn hash(&self, state: &mut H) { + if self.zero { + self.zero.hash(state); + } else { + self.x.hash(state); + self.y.hash(state); + } + } +} + /// A point on a short Weierstrass curve, represented in projective coordinates. #[derive(Copy, Clone, Debug)] pub struct ProjectivePoint { diff --git a/plonky2/src/curve/ecdsa.rs b/plonky2/src/curve/ecdsa.rs index 11e05535..cabe038a 100644 --- a/plonky2/src/curve/ecdsa.rs +++ b/plonky2/src/curve/ecdsa.rs @@ -13,7 +13,7 @@ pub struct ECDSASignature { #[derive(Copy, Clone, Debug, Deserialize, Eq, Hash, PartialEq, Serialize)] pub struct ECDSASecretKey(pub C::ScalarField); -#[derive(Copy, Clone, Debug, Deserialize, Eq, PartialEq, Serialize)] +#[derive(Copy, Clone, Debug, Deserialize, Eq, Hash, PartialEq, Serialize)] pub struct ECDSAPublicKey(pub AffinePoint); pub fn sign_message(msg: C::ScalarField, sk: ECDSASecretKey) -> ECDSASignature { From 8d699edf21a1e7276aa465df0a88595b6df1656b Mon Sep 17 00:00:00 2001 From: Daniel Lubarov Date: Mon, 14 Feb 2022 13:47:33 -0800 Subject: [PATCH 15/15] Move some methods outside `impl System` (#484) I didn't really have a good reason for putting there; seems more idiomatic to make them global since they don't need `self`/`Self`. --- system_zero/src/arithmetic/addition.rs | 4 +- system_zero/src/arithmetic/division.rs | 4 +- system_zero/src/arithmetic/mod.rs | 4 +- system_zero/src/arithmetic/multiplication.rs | 4 +- system_zero/src/arithmetic/subtraction.rs | 4 +- system_zero/src/core_registers.rs | 170 +++++---- system_zero/src/permutation_unit.rs | 354 +++++++++---------- system_zero/src/system_zero.rs | 24 +- 8 files changed, 282 insertions(+), 286 deletions(-) diff --git a/system_zero/src/arithmetic/addition.rs b/system_zero/src/arithmetic/addition.rs index 653d533b..7aa0d81a 100644 --- a/system_zero/src/arithmetic/addition.rs +++ b/system_zero/src/arithmetic/addition.rs @@ -1,5 +1,5 @@ use plonky2::field::extension_field::Extendable; -use plonky2::field::field_types::Field; +use plonky2::field::field_types::{Field, PrimeField64}; use plonky2::field::packed_field::PackedField; use plonky2::hash::hash_types::RichField; use plonky2::iop::ext_target::ExtensionTarget; @@ -10,7 +10,7 @@ use starky::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsume use crate::registers::arithmetic::*; use crate::registers::NUM_COLUMNS; -pub(crate) fn generate_addition(values: &mut [F; NUM_COLUMNS]) { +pub(crate) fn generate_addition(values: &mut [F; NUM_COLUMNS]) { let in_1 = values[COL_ADD_INPUT_1].to_canonical_u64(); let in_2 = values[COL_ADD_INPUT_2].to_canonical_u64(); let in_3 = values[COL_ADD_INPUT_3].to_canonical_u64(); diff --git a/system_zero/src/arithmetic/division.rs b/system_zero/src/arithmetic/division.rs index 2f15b233..e91288b9 100644 --- a/system_zero/src/arithmetic/division.rs +++ b/system_zero/src/arithmetic/division.rs @@ -1,5 +1,5 @@ use plonky2::field::extension_field::Extendable; -use plonky2::field::field_types::Field; +use plonky2::field::field_types::{Field, PrimeField64}; use plonky2::field::packed_field::PackedField; use plonky2::hash::hash_types::RichField; use plonky2::iop::ext_target::ExtensionTarget; @@ -9,7 +9,7 @@ use starky::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsume use crate::registers::arithmetic::*; use crate::registers::NUM_COLUMNS; -pub(crate) fn 
generate_division(values: &mut [F; NUM_COLUMNS]) { +pub(crate) fn generate_division(values: &mut [F; NUM_COLUMNS]) { // TODO } diff --git a/system_zero/src/arithmetic/mod.rs b/system_zero/src/arithmetic/mod.rs index 45a9f7d9..a2b3a4f8 100644 --- a/system_zero/src/arithmetic/mod.rs +++ b/system_zero/src/arithmetic/mod.rs @@ -1,5 +1,5 @@ use plonky2::field::extension_field::Extendable; -use plonky2::field::field_types::Field; +use plonky2::field::field_types::{Field, PrimeField64}; use plonky2::field::packed_field::PackedField; use plonky2::hash::hash_types::RichField; use plonky2::plonk::circuit_builder::CircuitBuilder; @@ -24,7 +24,7 @@ mod division; mod multiplication; mod subtraction; -pub(crate) fn generate_arithmetic_unit(values: &mut [F; NUM_COLUMNS]) { +pub(crate) fn generate_arithmetic_unit(values: &mut [F; NUM_COLUMNS]) { if values[IS_ADD].is_one() { generate_addition(values); } else if values[IS_SUB].is_one() { diff --git a/system_zero/src/arithmetic/multiplication.rs b/system_zero/src/arithmetic/multiplication.rs index 2eefad38..70c181d8 100644 --- a/system_zero/src/arithmetic/multiplication.rs +++ b/system_zero/src/arithmetic/multiplication.rs @@ -1,5 +1,5 @@ use plonky2::field::extension_field::Extendable; -use plonky2::field::field_types::Field; +use plonky2::field::field_types::{Field, PrimeField64}; use plonky2::field::packed_field::PackedField; use plonky2::hash::hash_types::RichField; use plonky2::iop::ext_target::ExtensionTarget; @@ -9,7 +9,7 @@ use starky::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsume use crate::registers::arithmetic::*; use crate::registers::NUM_COLUMNS; -pub(crate) fn generate_multiplication(values: &mut [F; NUM_COLUMNS]) { +pub(crate) fn generate_multiplication(values: &mut [F; NUM_COLUMNS]) { // TODO } diff --git a/system_zero/src/arithmetic/subtraction.rs b/system_zero/src/arithmetic/subtraction.rs index 3613dee6..267bac72 100644 --- a/system_zero/src/arithmetic/subtraction.rs +++ b/system_zero/src/arithmetic/subtraction.rs @@ -1,5 +1,5 @@ use plonky2::field::extension_field::Extendable; -use plonky2::field::field_types::Field; +use plonky2::field::field_types::{Field, PrimeField64}; use plonky2::field::packed_field::PackedField; use plonky2::hash::hash_types::RichField; use plonky2::iop::ext_target::ExtensionTarget; @@ -9,7 +9,7 @@ use starky::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsume use crate::registers::arithmetic::*; use crate::registers::NUM_COLUMNS; -pub(crate) fn generate_subtraction(values: &mut [F; NUM_COLUMNS]) { +pub(crate) fn generate_subtraction(values: &mut [F; NUM_COLUMNS]) { // TODO } diff --git a/system_zero/src/core_registers.rs b/system_zero/src/core_registers.rs index 03e7fa04..c8c6533b 100644 --- a/system_zero/src/core_registers.rs +++ b/system_zero/src/core_registers.rs @@ -1,4 +1,5 @@ -use plonky2::field::extension_field::{Extendable, FieldExtension}; +use plonky2::field::extension_field::Extendable; +use plonky2::field::field_types::{Field, PrimeField64}; use plonky2::field::packed_field::PackedField; use plonky2::hash::hash_types::RichField; use plonky2::plonk::circuit_builder::CircuitBuilder; @@ -9,93 +10,84 @@ use starky::vars::StarkEvaluationVars; use crate::public_input_layout::NUM_PUBLIC_INPUTS; use crate::registers::core::*; use crate::registers::NUM_COLUMNS; -use crate::system_zero::SystemZero; -impl, const D: usize> SystemZero { - pub(crate) fn generate_first_row_core_registers(&self, first_values: &mut [F; NUM_COLUMNS]) { - first_values[COL_CLOCK] = F::ZERO; - 
first_values[COL_RANGE_16] = F::ZERO; - first_values[COL_INSTRUCTION_PTR] = F::ZERO; - first_values[COL_FRAME_PTR] = F::ZERO; - first_values[COL_STACK_PTR] = F::ZERO; - } - - pub(crate) fn generate_next_row_core_registers( - &self, - local_values: &[F; NUM_COLUMNS], - next_values: &mut [F; NUM_COLUMNS], - ) { - // We increment the clock by 1. - next_values[COL_CLOCK] = local_values[COL_CLOCK] + F::ONE; - - // We increment the 16-bit table by 1, unless we've reached the max value of 2^16 - 1, in - // which case we repeat that value. - let prev_range_16 = local_values[COL_RANGE_16].to_canonical_u64(); - let next_range_16 = (prev_range_16 + 1).min((1 << 16) - 1); - next_values[COL_RANGE_16] = F::from_canonical_u64(next_range_16); - - // next_values[COL_INSTRUCTION_PTR] = todo!(); - - // next_values[COL_FRAME_PTR] = todo!(); - - // next_values[COL_STACK_PTR] = todo!(); - } - - #[inline] - pub(crate) fn eval_core_registers( - &self, - vars: StarkEvaluationVars, - yield_constr: &mut ConstraintConsumer
<P>
, - ) where - FE: FieldExtension, - P: PackedField, - { - // The clock must start with 0, and increment by 1. - let local_clock = vars.local_values[COL_CLOCK]; - let next_clock = vars.next_values[COL_CLOCK]; - let delta_clock = next_clock - local_clock; - yield_constr.constraint_first_row(local_clock); - yield_constr.constraint(delta_clock - FE::ONE); - - // The 16-bit table must start with 0, end with 2^16 - 1, and increment by 0 or 1. - let local_range_16 = vars.local_values[COL_RANGE_16]; - let next_range_16 = vars.next_values[COL_RANGE_16]; - let delta_range_16 = next_range_16 - local_range_16; - yield_constr.constraint_first_row(local_range_16); - yield_constr.constraint_last_row(local_range_16 - FE::from_canonical_u64((1 << 16) - 1)); - yield_constr.constraint(delta_range_16 * delta_range_16 - delta_range_16); - - // TODO constraints for stack etc. - } - - pub(crate) fn eval_core_registers_recursively( - &self, - builder: &mut CircuitBuilder, - vars: StarkEvaluationTargets, - yield_constr: &mut RecursiveConstraintConsumer, - ) { - let one_ext = builder.one_extension(); - let max_u16 = builder.constant(F::from_canonical_u64((1 << 16) - 1)); - let max_u16_ext = builder.convert_to_ext(max_u16); - - // The clock must start with 0, and increment by 1. - let local_clock = vars.local_values[COL_CLOCK]; - let next_clock = vars.next_values[COL_CLOCK]; - let delta_clock = builder.sub_extension(next_clock, local_clock); - yield_constr.constraint_first_row(builder, local_clock); - let constraint = builder.sub_extension(delta_clock, one_ext); - yield_constr.constraint(builder, constraint); - - // The 16-bit table must start with 0, end with 2^16 - 1, and increment by 0 or 1. - let local_range_16 = vars.local_values[COL_RANGE_16]; - let next_range_16 = vars.next_values[COL_RANGE_16]; - let delta_range_16 = builder.sub_extension(next_range_16, local_range_16); - yield_constr.constraint_first_row(builder, local_range_16); - let constraint = builder.sub_extension(local_range_16, max_u16_ext); - yield_constr.constraint_last_row(builder, constraint); - let constraint = builder.mul_add_extension(delta_range_16, delta_range_16, delta_range_16); - yield_constr.constraint(builder, constraint); - - // TODO constraints for stack etc. - } +pub(crate) fn generate_first_row_core_registers(first_values: &mut [F; NUM_COLUMNS]) { + first_values[COL_CLOCK] = F::ZERO; + first_values[COL_RANGE_16] = F::ZERO; + first_values[COL_INSTRUCTION_PTR] = F::ZERO; + first_values[COL_FRAME_PTR] = F::ZERO; + first_values[COL_STACK_PTR] = F::ZERO; +} + +pub(crate) fn generate_next_row_core_registers( + local_values: &[F; NUM_COLUMNS], + next_values: &mut [F; NUM_COLUMNS], +) { + // We increment the clock by 1. + next_values[COL_CLOCK] = local_values[COL_CLOCK] + F::ONE; + + // We increment the 16-bit table by 1, unless we've reached the max value of 2^16 - 1, in + // which case we repeat that value. + let prev_range_16 = local_values[COL_RANGE_16].to_canonical_u64(); + let next_range_16 = (prev_range_16 + 1).min((1 << 16) - 1); + next_values[COL_RANGE_16] = F::from_canonical_u64(next_range_16); + + // next_values[COL_INSTRUCTION_PTR] = todo!(); + + // next_values[COL_FRAME_PTR] = todo!(); + + // next_values[COL_STACK_PTR] = todo!(); +} + +#[inline] +pub(crate) fn eval_core_registers>( + vars: StarkEvaluationVars, + yield_constr: &mut ConstraintConsumer
<P>
, +) { + // The clock must start with 0, and increment by 1. + let local_clock = vars.local_values[COL_CLOCK]; + let next_clock = vars.next_values[COL_CLOCK]; + let delta_clock = next_clock - local_clock; + yield_constr.constraint_first_row(local_clock); + yield_constr.constraint(delta_clock - F::ONE); + + // The 16-bit table must start with 0, end with 2^16 - 1, and increment by 0 or 1. + let local_range_16 = vars.local_values[COL_RANGE_16]; + let next_range_16 = vars.next_values[COL_RANGE_16]; + let delta_range_16 = next_range_16 - local_range_16; + yield_constr.constraint_first_row(local_range_16); + yield_constr.constraint_last_row(local_range_16 - F::from_canonical_u64((1 << 16) - 1)); + yield_constr.constraint(delta_range_16 * delta_range_16 - delta_range_16); + + // TODO constraints for stack etc. +} + +pub(crate) fn eval_core_registers_recursively, const D: usize>( + builder: &mut CircuitBuilder, + vars: StarkEvaluationTargets, + yield_constr: &mut RecursiveConstraintConsumer, +) { + let one_ext = builder.one_extension(); + let max_u16 = builder.constant(F::from_canonical_u64((1 << 16) - 1)); + let max_u16_ext = builder.convert_to_ext(max_u16); + + // The clock must start with 0, and increment by 1. + let local_clock = vars.local_values[COL_CLOCK]; + let next_clock = vars.next_values[COL_CLOCK]; + let delta_clock = builder.sub_extension(next_clock, local_clock); + yield_constr.constraint_first_row(builder, local_clock); + let constraint = builder.sub_extension(delta_clock, one_ext); + yield_constr.constraint(builder, constraint); + + // The 16-bit table must start with 0, end with 2^16 - 1, and increment by 0 or 1. + let local_range_16 = vars.local_values[COL_RANGE_16]; + let next_range_16 = vars.next_values[COL_RANGE_16]; + let delta_range_16 = builder.sub_extension(next_range_16, local_range_16); + yield_constr.constraint_first_row(builder, local_range_16); + let constraint = builder.sub_extension(local_range_16, max_u16_ext); + yield_constr.constraint_last_row(builder, constraint); + let constraint = builder.mul_add_extension(delta_range_16, delta_range_16, delta_range_16); + yield_constr.constraint(builder, constraint); + + // TODO constraints for stack etc. } diff --git a/system_zero/src/permutation_unit.rs b/system_zero/src/permutation_unit.rs index 2681f2d9..366cff65 100644 --- a/system_zero/src/permutation_unit.rs +++ b/system_zero/src/permutation_unit.rs @@ -2,7 +2,7 @@ use plonky2::field::extension_field::{Extendable, FieldExtension}; use plonky2::field::packed_field::PackedField; use plonky2::hash::hash_types::RichField; use plonky2::hash::hashing::SPONGE_WIDTH; -use plonky2::hash::poseidon::{HALF_N_FULL_ROUNDS, N_PARTIAL_ROUNDS}; +use plonky2::hash::poseidon::{Poseidon, HALF_N_FULL_ROUNDS, N_PARTIAL_ROUNDS}; use plonky2::plonk::circuit_builder::CircuitBuilder; use starky::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; use starky::vars::StarkEvaluationTargets; @@ -11,15 +11,14 @@ use starky::vars::StarkEvaluationVars; use crate::public_input_layout::NUM_PUBLIC_INPUTS; use crate::registers::permutation::*; use crate::registers::NUM_COLUMNS; -use crate::system_zero::SystemZero; -fn constant_layer( +fn constant_layer( mut state: [P; SPONGE_WIDTH], round: usize, ) -> [P; SPONGE_WIDTH] where - F: RichField, - FE: FieldExtension, + F: Poseidon, + FE: FieldExtension, P: PackedField, { // One day I might actually vectorize this, but today is not that day. 
@@ -36,10 +35,10 @@ where state } -fn mds_layer(mut state: [P; SPONGE_WIDTH]) -> [P; SPONGE_WIDTH] +fn mds_layer(mut state: [P; SPONGE_WIDTH]) -> [P; SPONGE_WIDTH] where - F: RichField, - FE: FieldExtension, + F: Poseidon, + FE: FieldExtension, P: PackedField, { for i in 0..P::WIDTH { @@ -55,206 +54,204 @@ where state } -impl, const D: usize> SystemZero { - pub(crate) fn generate_permutation_unit(values: &mut [F; NUM_COLUMNS]) { - // Load inputs. - let mut state = [F::ZERO; SPONGE_WIDTH]; +pub(crate) fn generate_permutation_unit(values: &mut [F; NUM_COLUMNS]) { + // Load inputs. + let mut state = [F::ZERO; SPONGE_WIDTH]; + for i in 0..SPONGE_WIDTH { + state[i] = values[col_input(i)]; + } + + for r in 0..HALF_N_FULL_ROUNDS { + F::constant_layer(&mut state, r); + for i in 0..SPONGE_WIDTH { - state[i] = values[col_input(i)]; + let state_cubed = state[i].cube(); + values[col_full_first_mid_sbox(r, i)] = state_cubed; + state[i] *= state_cubed.square(); // Form state ** 7. } - for r in 0..HALF_N_FULL_ROUNDS { - F::constant_layer(&mut state, r); + state = F::mds_layer(&state); - for i in 0..SPONGE_WIDTH { - let state_cubed = state[i].cube(); - values[col_full_first_mid_sbox(r, i)] = state_cubed; - state[i] *= state_cubed.square(); // Form state ** 7. - } - - state = F::mds_layer(&state); - - for i in 0..SPONGE_WIDTH { - values[col_full_first_after_mds(r, i)] = state[i]; - } - } - - for r in 0..N_PARTIAL_ROUNDS { - F::constant_layer(&mut state, HALF_N_FULL_ROUNDS + r); - - let state0_cubed = state[0].cube(); - values[col_partial_mid_sbox(r)] = state0_cubed; - state[0] *= state0_cubed.square(); // Form state ** 7. - values[col_partial_after_sbox(r)] = state[0]; - - state = F::mds_layer(&state); - } - - for r in 0..HALF_N_FULL_ROUNDS { - F::constant_layer(&mut state, HALF_N_FULL_ROUNDS + N_PARTIAL_ROUNDS + r); - - for i in 0..SPONGE_WIDTH { - let state_cubed = state[i].cube(); - values[col_full_second_mid_sbox(r, i)] = state_cubed; - state[i] *= state_cubed.square(); // Form state ** 7. - } - - state = F::mds_layer(&state); - - for i in 0..SPONGE_WIDTH { - values[col_full_second_after_mds(r, i)] = state[i]; - } + for i in 0..SPONGE_WIDTH { + values[col_full_first_after_mds(r, i)] = state[i]; } } - #[inline] - pub(crate) fn eval_permutation_unit( - vars: StarkEvaluationVars, - yield_constr: &mut ConstraintConsumer
<P>
, - ) where - FE: FieldExtension, - P: PackedField, - { - let local_values = &vars.local_values; + for r in 0..N_PARTIAL_ROUNDS { + F::constant_layer(&mut state, HALF_N_FULL_ROUNDS + r); + + let state0_cubed = state[0].cube(); + values[col_partial_mid_sbox(r)] = state0_cubed; + state[0] *= state0_cubed.square(); // Form state ** 7. + values[col_partial_after_sbox(r)] = state[0]; + + state = F::mds_layer(&state); + } + + for r in 0..HALF_N_FULL_ROUNDS { + F::constant_layer(&mut state, HALF_N_FULL_ROUNDS + N_PARTIAL_ROUNDS + r); - // Load inputs. - let mut state = [P::ZEROS; SPONGE_WIDTH]; for i in 0..SPONGE_WIDTH { - state[i] = local_values[col_input(i)]; + let state_cubed = state[i].cube(); + values[col_full_second_mid_sbox(r, i)] = state_cubed; + state[i] *= state_cubed.square(); // Form state ** 7. } - for r in 0..HALF_N_FULL_ROUNDS { - state = constant_layer(state, r); + state = F::mds_layer(&state); - for i in 0..SPONGE_WIDTH { - let state_cubed = state[i] * state[i].square(); - yield_constr - .constraint_wrapping(state_cubed - local_values[col_full_first_mid_sbox(r, i)]); - let state_cubed = local_values[col_full_first_mid_sbox(r, i)]; - state[i] *= state_cubed.square(); // Form state ** 7. - } + for i in 0..SPONGE_WIDTH { + values[col_full_second_after_mds(r, i)] = state[i]; + } + } +} - state = mds_layer(state); +#[inline] +pub(crate) fn eval_permutation_unit( + vars: StarkEvaluationVars, + yield_constr: &mut ConstraintConsumer
<P>
, +) where + F: Poseidon, + FE: FieldExtension, + P: PackedField, +{ + let local_values = &vars.local_values; - for i in 0..SPONGE_WIDTH { - yield_constr - .constraint_wrapping(state[i] - local_values[col_full_first_after_mds(r, i)]); - state[i] = local_values[col_full_first_after_mds(r, i)]; - } + // Load inputs. + let mut state = [P::ZEROS; SPONGE_WIDTH]; + for i in 0..SPONGE_WIDTH { + state[i] = local_values[col_input(i)]; + } + + for r in 0..HALF_N_FULL_ROUNDS { + state = constant_layer(state, r); + + for i in 0..SPONGE_WIDTH { + let state_cubed = state[i] * state[i].square(); + yield_constr + .constraint_wrapping(state_cubed - local_values[col_full_first_mid_sbox(r, i)]); + let state_cubed = local_values[col_full_first_mid_sbox(r, i)]; + state[i] *= state_cubed.square(); // Form state ** 7. } - for r in 0..N_PARTIAL_ROUNDS { - state = constant_layer(state, HALF_N_FULL_ROUNDS + r); + state = mds_layer(state); - let state0_cubed = state[0] * state[0].square(); - yield_constr.constraint_wrapping(state0_cubed - local_values[col_partial_mid_sbox(r)]); - let state0_cubed = local_values[col_partial_mid_sbox(r)]; - state[0] *= state0_cubed.square(); // Form state ** 7. - yield_constr.constraint_wrapping(state[0] - local_values[col_partial_after_sbox(r)]); - state[0] = local_values[col_partial_after_sbox(r)]; - - state = mds_layer(state); - } - - for r in 0..HALF_N_FULL_ROUNDS { - state = constant_layer(state, HALF_N_FULL_ROUNDS + N_PARTIAL_ROUNDS + r); - - for i in 0..SPONGE_WIDTH { - let state_cubed = state[i] * state[i].square(); - yield_constr.constraint_wrapping( - state_cubed - local_values[col_full_second_mid_sbox(r, i)], - ); - let state_cubed = local_values[col_full_second_mid_sbox(r, i)]; - state[i] *= state_cubed.square(); // Form state ** 7. - } - - state = mds_layer(state); - - for i in 0..SPONGE_WIDTH { - yield_constr - .constraint_wrapping(state[i] - local_values[col_full_second_after_mds(r, i)]); - state[i] = local_values[col_full_second_after_mds(r, i)]; - } + for i in 0..SPONGE_WIDTH { + yield_constr + .constraint_wrapping(state[i] - local_values[col_full_first_after_mds(r, i)]); + state[i] = local_values[col_full_first_after_mds(r, i)]; } } - pub(crate) fn eval_permutation_unit_recursively( - builder: &mut CircuitBuilder, - vars: StarkEvaluationTargets, - yield_constr: &mut RecursiveConstraintConsumer, - ) { - let zero = builder.zero_extension(); - let local_values = &vars.local_values; + for r in 0..N_PARTIAL_ROUNDS { + state = constant_layer(state, HALF_N_FULL_ROUNDS + r); + + let state0_cubed = state[0] * state[0].square(); + yield_constr.constraint_wrapping(state0_cubed - local_values[col_partial_mid_sbox(r)]); + let state0_cubed = local_values[col_partial_mid_sbox(r)]; + state[0] *= state0_cubed.square(); // Form state ** 7. + yield_constr.constraint_wrapping(state[0] - local_values[col_partial_after_sbox(r)]); + state[0] = local_values[col_partial_after_sbox(r)]; + + state = mds_layer(state); + } + + for r in 0..HALF_N_FULL_ROUNDS { + state = constant_layer(state, HALF_N_FULL_ROUNDS + N_PARTIAL_ROUNDS + r); - // Load inputs. - let mut state = [zero; SPONGE_WIDTH]; for i in 0..SPONGE_WIDTH { - state[i] = local_values[col_input(i)]; + let state_cubed = state[i] * state[i].square(); + yield_constr + .constraint_wrapping(state_cubed - local_values[col_full_second_mid_sbox(r, i)]); + let state_cubed = local_values[col_full_second_mid_sbox(r, i)]; + state[i] *= state_cubed.square(); // Form state ** 7. 
} - for r in 0..HALF_N_FULL_ROUNDS { - F::constant_layer_recursive(builder, &mut state, r); + state = mds_layer(state); - for i in 0..SPONGE_WIDTH { - let state_cubed = builder.cube_extension(state[i]); - let diff = - builder.sub_extension(state_cubed, local_values[col_full_first_mid_sbox(r, i)]); - yield_constr.constraint_wrapping(builder, diff); - let state_cubed = local_values[col_full_first_mid_sbox(r, i)]; - state[i] = builder.mul_many_extension(&[state[i], state_cubed, state_cubed]); - // Form state ** 7. - } - - state = F::mds_layer_recursive(builder, &state); - - for i in 0..SPONGE_WIDTH { - let diff = - builder.sub_extension(state[i], local_values[col_full_first_after_mds(r, i)]); - yield_constr.constraint_wrapping(builder, diff); - state[i] = local_values[col_full_first_after_mds(r, i)]; - } + for i in 0..SPONGE_WIDTH { + yield_constr + .constraint_wrapping(state[i] - local_values[col_full_second_after_mds(r, i)]); + state[i] = local_values[col_full_second_after_mds(r, i)]; } + } +} - for r in 0..N_PARTIAL_ROUNDS { - F::constant_layer_recursive(builder, &mut state, HALF_N_FULL_ROUNDS + r); +pub(crate) fn eval_permutation_unit_recursively, const D: usize>( + builder: &mut CircuitBuilder, + vars: StarkEvaluationTargets, + yield_constr: &mut RecursiveConstraintConsumer, +) { + let zero = builder.zero_extension(); + let local_values = &vars.local_values; - let state0_cubed = builder.cube_extension(state[0]); - let diff = builder.sub_extension(state0_cubed, local_values[col_partial_mid_sbox(r)]); + // Load inputs. + let mut state = [zero; SPONGE_WIDTH]; + for i in 0..SPONGE_WIDTH { + state[i] = local_values[col_input(i)]; + } + + for r in 0..HALF_N_FULL_ROUNDS { + F::constant_layer_recursive(builder, &mut state, r); + + for i in 0..SPONGE_WIDTH { + let state_cubed = builder.cube_extension(state[i]); + let diff = + builder.sub_extension(state_cubed, local_values[col_full_first_mid_sbox(r, i)]); yield_constr.constraint_wrapping(builder, diff); - let state0_cubed = local_values[col_partial_mid_sbox(r)]; - state[0] = builder.mul_many_extension(&[state[0], state0_cubed, state0_cubed]); // Form state ** 7. - let diff = builder.sub_extension(state[0], local_values[col_partial_after_sbox(r)]); - yield_constr.constraint_wrapping(builder, diff); - state[0] = local_values[col_partial_after_sbox(r)]; - - state = F::mds_layer_recursive(builder, &state); + let state_cubed = local_values[col_full_first_mid_sbox(r, i)]; + state[i] = builder.mul_many_extension(&[state[i], state_cubed, state_cubed]); + // Form state ** 7. } - for r in 0..HALF_N_FULL_ROUNDS { - F::constant_layer_recursive( - builder, - &mut state, - HALF_N_FULL_ROUNDS + N_PARTIAL_ROUNDS + r, - ); + state = F::mds_layer_recursive(builder, &state); - for i in 0..SPONGE_WIDTH { - let state_cubed = builder.cube_extension(state[i]); - let diff = builder - .sub_extension(state_cubed, local_values[col_full_second_mid_sbox(r, i)]); - yield_constr.constraint_wrapping(builder, diff); - let state_cubed = local_values[col_full_second_mid_sbox(r, i)]; - state[i] = builder.mul_many_extension(&[state[i], state_cubed, state_cubed]); - // Form state ** 7. 
- } + for i in 0..SPONGE_WIDTH { + let diff = + builder.sub_extension(state[i], local_values[col_full_first_after_mds(r, i)]); + yield_constr.constraint_wrapping(builder, diff); + state[i] = local_values[col_full_first_after_mds(r, i)]; + } + } - state = F::mds_layer_recursive(builder, &state); + for r in 0..N_PARTIAL_ROUNDS { + F::constant_layer_recursive(builder, &mut state, HALF_N_FULL_ROUNDS + r); - for i in 0..SPONGE_WIDTH { - let diff = - builder.sub_extension(state[i], local_values[col_full_second_after_mds(r, i)]); - yield_constr.constraint_wrapping(builder, diff); - state[i] = local_values[col_full_second_after_mds(r, i)]; - } + let state0_cubed = builder.cube_extension(state[0]); + let diff = builder.sub_extension(state0_cubed, local_values[col_partial_mid_sbox(r)]); + yield_constr.constraint_wrapping(builder, diff); + let state0_cubed = local_values[col_partial_mid_sbox(r)]; + state[0] = builder.mul_many_extension(&[state[0], state0_cubed, state0_cubed]); // Form state ** 7. + let diff = builder.sub_extension(state[0], local_values[col_partial_after_sbox(r)]); + yield_constr.constraint_wrapping(builder, diff); + state[0] = local_values[col_partial_after_sbox(r)]; + + state = F::mds_layer_recursive(builder, &state); + } + + for r in 0..HALF_N_FULL_ROUNDS { + F::constant_layer_recursive( + builder, + &mut state, + HALF_N_FULL_ROUNDS + N_PARTIAL_ROUNDS + r, + ); + + for i in 0..SPONGE_WIDTH { + let state_cubed = builder.cube_extension(state[i]); + let diff = + builder.sub_extension(state_cubed, local_values[col_full_second_mid_sbox(r, i)]); + yield_constr.constraint_wrapping(builder, diff); + let state_cubed = local_values[col_full_second_mid_sbox(r, i)]; + state[i] = builder.mul_many_extension(&[state[i], state_cubed, state_cubed]); + // Form state ** 7. 
+ } + + state = F::mds_layer_recursive(builder, &state); + + for i in 0..SPONGE_WIDTH { + let diff = + builder.sub_extension(state[i], local_values[col_full_second_after_mds(r, i)]); + yield_constr.constraint_wrapping(builder, diff); + state[i] = local_values[col_full_second_after_mds(r, i)]; } } } @@ -269,11 +266,10 @@ mod tests { use starky::constraint_consumer::ConstraintConsumer; use starky::vars::StarkEvaluationVars; - use crate::permutation_unit::SPONGE_WIDTH; + use crate::permutation_unit::{eval_permutation_unit, generate_permutation_unit, SPONGE_WIDTH}; use crate::public_input_layout::NUM_PUBLIC_INPUTS; use crate::registers::permutation::{col_input, col_output}; use crate::registers::NUM_COLUMNS; - use crate::system_zero::SystemZero; #[test] fn generate_eval_consistency() { @@ -281,7 +277,7 @@ mod tests { type F = GoldilocksField; let mut values = [F::default(); NUM_COLUMNS]; - SystemZero::::generate_permutation_unit(&mut values); + generate_permutation_unit(&mut values); let vars = StarkEvaluationVars { local_values: &values, @@ -295,7 +291,7 @@ mod tests { GoldilocksField::ONE, GoldilocksField::ONE, ); - SystemZero::::eval_permutation_unit(vars, &mut constrant_consumer); + eval_permutation_unit(vars, &mut constrant_consumer); for &acc in &constrant_consumer.constraint_accs { assert_eq!(acc, GoldilocksField::ZERO); } @@ -318,7 +314,7 @@ mod tests { for i in 0..SPONGE_WIDTH { values[col_input(i)] = state[i]; } - SystemZero::::generate_permutation_unit(&mut values); + generate_permutation_unit(&mut values); let mut result = [F::default(); SPONGE_WIDTH]; for i in 0..SPONGE_WIDTH { result[i] = values[col_output(i)]; diff --git a/system_zero/src/system_zero.rs b/system_zero/src/system_zero.rs index 780b1d38..2eeb4697 100644 --- a/system_zero/src/system_zero.rs +++ b/system_zero/src/system_zero.rs @@ -12,7 +12,14 @@ use starky::vars::StarkEvaluationVars; use crate::arithmetic::{ eval_arithmetic_unit, eval_arithmetic_unit_recursively, generate_arithmetic_unit, }; +use crate::core_registers::{ + eval_core_registers, eval_core_registers_recursively, generate_first_row_core_registers, + generate_next_row_core_registers, +}; use crate::memory::TransactionMemory; +use crate::permutation_unit::{ + eval_permutation_unit, eval_permutation_unit_recursively, generate_permutation_unit, +}; use crate::public_input_layout::NUM_PUBLIC_INPUTS; use crate::registers::NUM_COLUMNS; @@ -29,16 +36,17 @@ impl, const D: usize> SystemZero { let memory = TransactionMemory::default(); let mut row = [F::ZERO; NUM_COLUMNS]; - self.generate_first_row_core_registers(&mut row); - Self::generate_permutation_unit(&mut row); + generate_first_row_core_registers(&mut row); + generate_arithmetic_unit(&mut row); + generate_permutation_unit(&mut row); let mut trace = Vec::with_capacity(MIN_TRACE_ROWS); loop { let mut next_row = [F::ZERO; NUM_COLUMNS]; - self.generate_next_row_core_registers(&row, &mut next_row); + generate_next_row_core_registers(&row, &mut next_row); generate_arithmetic_unit(&mut next_row); - Self::generate_permutation_unit(&mut next_row); + generate_permutation_unit(&mut next_row); trace.push(row); row = next_row; @@ -74,9 +82,9 @@ impl, const D: usize> Stark for SystemZero, P: PackedField, { - self.eval_core_registers(vars, yield_constr); + eval_core_registers(vars, yield_constr); eval_arithmetic_unit(vars, yield_constr); - Self::eval_permutation_unit(vars, yield_constr); + eval_permutation_unit::(vars, yield_constr); // TODO: Other units } @@ -86,9 +94,9 @@ impl, const D: usize> Stark for SystemZero, 
        yield_constr: &mut RecursiveConstraintConsumer<F, D>,
     ) {
-        self.eval_core_registers_recursively(builder, vars, yield_constr);
+        eval_core_registers_recursively(builder, vars, yield_constr);
         eval_arithmetic_unit_recursively(builder, vars, yield_constr);
-        Self::eval_permutation_unit_recursively(builder, vars, yield_constr);
+        eval_permutation_unit_recursively(builder, vars, yield_constr);
 
         // TODO: Other units
     }
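The rationale behind PATCH 15/15 — helpers that never touch `self` or `Self` read better as free functions — is easiest to see in a minimal sketch. The names below (`Unit`, `generate_row_method`, `generate_row`, `NUM_COLS`) are hypothetical and only illustrate the before/after shape of the refactor; they are not identifiers from the patch itself.

```rust
// Illustrative sketch only; names are hypothetical, not taken from the patch.
const NUM_COLS: usize = 4;

struct Unit;

impl Unit {
    // Before: an associated function that never uses `self` or `Self`,
    // so the `impl` block adds nothing but an extra path segment.
    fn generate_row_method(values: &mut [u64; NUM_COLS]) {
        values[0] = 1;
    }
}

// After: the same logic as a free function, callable without naming the type.
fn generate_row(values: &mut [u64; NUM_COLS]) {
    values[0] = 1;
}

fn main() {
    let mut row = [0u64; NUM_COLS];
    Unit::generate_row_method(&mut row); // before: type-qualified call required
    generate_row(&mut row);              // after: plain call, easy to re-export
    assert_eq!(row[0], 1);
}
```

The same shift is what lets the `system_zero.rs` and `permutation_unit` call sites above drop their `self.`/`Self::`/`SystemZero::` prefixes and import the generators and constraint evaluators directly.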